// Copyright (C) 2013  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.

#include <dlib/python.h>
#include <dlib/matrix.h>
#include <boost/python/args.hpp>
#include <dlib/svm.h>


using namespace dlib;
using namespace std;
using namespace boost::python;

template <typename psi_type>
class svm_struct_prob : public structural_svm_problem<matrix<double,0,1>, psi_type>
{
    typedef structural_svm_problem<matrix<double,0,1>, psi_type> base;
    typedef typename base::feature_vector_type feature_vector_type;
    typedef typename base::matrix_type matrix_type;
    typedef typename base::scalar_type scalar_type;
public:
    svm_struct_prob (
        object& problem_,
        long num_dimensions_,
        long num_samples_
    ) : 
        num_dimensions(num_dimensions_),
        num_samples(num_samples_),
        problem(problem_)
    {} 

    virtual long get_num_dimensions (
    ) const { return num_dimensions; }

    virtual long get_num_samples (
    ) const { return num_samples; }

    virtual void get_truth_joint_feature_vector (
        long idx,
        feature_vector_type& psi 
    ) const 
    {
        psi = extract<feature_vector_type&>(problem.attr("get_truth_joint_feature_vector")(idx));
    }

    virtual void separation_oracle (
        const long idx,
        const matrix_type& current_solution,
        scalar_type& loss,
        feature_vector_type& psi
    ) const 
    {
        object res = problem.attr("separation_oracle")(idx,boost::ref(current_solution));
        pyassert(len(res) == 2, "separation_oracle() must return two objects, the loss and the psi vector");
        // let the user supply the output arguments in any order.
        if (extract<double>(res[0]).check())
        {
            loss = extract<double>(res[0]);
            psi = extract<feature_vector_type&>(res[1]);
        }
        else
        {
            psi = extract<feature_vector_type&>(res[0]);
            loss = extract<double>(res[1]);
        }
    }

private:

    const long num_dimensions;
    const long num_samples;
    object& problem;
};

// ----------------------------------------------------------------------------------------

template <typename psi_type>
matrix<double,0,1> solve_structural_svm_problem_impl(
    object problem
)
{
    const double C = extract<double>(problem.attr("C"));
    const bool be_verbose = hasattr(problem,"be_verbose") && extract<bool>(problem.attr("be_verbose"));
    const bool use_sparse_feature_vectors = hasattr(problem,"use_sparse_feature_vectors") && 
                                            extract<bool>(problem.attr("use_sparse_feature_vectors"));
    const bool learns_nonnegative_weights = hasattr(problem,"learns_nonnegative_weights") && 
                                            extract<bool>(problem.attr("learns_nonnegative_weights"));

    double eps = 0.001;
    unsigned long max_cache_size = 10;
    if (hasattr(problem, "epsilon"))
        eps = extract<double>(problem.attr("epsilon"));
    if (hasattr(problem, "max_cache_size"))
        max_cache_size = extract<double>(problem.attr("max_cache_size"));

    const long num_samples = extract<long>(problem.attr("num_samples"));
    const long num_dimensions = extract<long>(problem.attr("num_dimensions"));

    pyassert(num_samples > 0, "You can't train a Structural-SVM if you don't have any training samples.");

    if (be_verbose)
    {
        cout << "C:              " << C << endl;
        cout << "epsilon:        " << eps << endl;
        cout << "max_cache_size: " << max_cache_size << endl;
        cout << "num_samples:    " << num_samples << endl;
        cout << "num_dimensions: " << num_dimensions << endl;
        cout << "use_sparse_feature_vectors: " << std::boolalpha << use_sparse_feature_vectors << endl;
        cout << "learns_nonnegative_weights: " << std::boolalpha << learns_nonnegative_weights << endl;
        cout << endl;
    }

    svm_struct_prob<psi_type> prob(problem, num_dimensions, num_samples);
    prob.set_c(C);
    prob.set_epsilon(eps);
    prob.set_max_cache_size(max_cache_size);
    if (be_verbose)
        prob.be_verbose();

    oca solver;
    matrix<double,0,1> w;
    if (learns_nonnegative_weights)
        solver(prob, w, prob.get_num_dimensions());
    else
        solver(prob, w);
    return w;
}

// ----------------------------------------------------------------------------------------

matrix<double,0,1> solve_structural_svm_problem(
    object problem
)
{
    // Check if the python code is using sparse or dense vectors to represent PSI()
    extract<matrix<double,0,1> > isdense(problem.attr("get_truth_joint_feature_vector")(0));
    if (isdense.check())
        return solve_structural_svm_problem_impl<matrix<double,0,1> >(problem);
    else
        return solve_structural_svm_problem_impl<std::vector<std::pair<unsigned long,double> > >(problem);
}

// ----------------------------------------------------------------------------------------

void bind_svm_struct()
{
    using boost::python::arg;

    def("solve_structural_svm_problem",solve_structural_svm_problem, (arg("problem")),
"This function solves a structural SVM problem and returns the weight vector    \n\
that defines the solution.  See the example program python_examples/svm_struct.py    \n\
for documentation about how to create a proper problem object.   " 
        );
}

// ----------------------------------------------------------------------------------------