// Copyright (C) 2013  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.

#include <boost/python.hpp>
#include <dlib/matrix.h>
#include <boost/python/args.hpp>
#include "pyassert.h"
#include "boost_python_utils.h"
#include <dlib/svm.h>


using namespace dlib;
using namespace std;
using namespace boost::python;

// Adapter that presents a Python "problem" object to dlib's structural SVM
// machinery as a structural_svm_problem using dense column vectors
// (matrix<double,0,1>) as joint feature vectors.
//
// The sample and dimension counts are fixed at construction.  The two data
// callbacks are forwarded to same-named methods on the Python object, and the
// vectors they return are copied into the output arguments and checked against
// num_dimensions before being handed to the solver.
class svm_struct_dense : public structural_svm_problem<matrix<double,0,1> >
{
public:
    // problem_        - Python object providing get_truth_joint_feature_vector()
    //                   and separation_oracle() methods.
    // num_dimensions_ - value reported by get_num_dimensions() and the size every
    //                   returned feature vector must have.
    // num_samples_    - value reported by get_num_samples().
    svm_struct_dense (
        object& problem_,
        long num_dimensions_,
        long num_samples_
    ) : 
        num_dimensions(num_dimensions_),
        num_samples(num_samples_),
        problem(problem_)
    {} 

    virtual long get_num_dimensions (
    ) const { return num_dimensions; }

    virtual long get_num_samples (
    ) const { return num_samples; }

    // Calls problem.get_truth_joint_feature_vector(idx) in Python and copies the
    // returned vector into psi.  Raises a Python ValueError (via pyassert) if the
    // vector does not have num_dimensions elements.
    virtual void get_truth_joint_feature_vector (
        long idx,
        feature_vector_type& psi 
    ) const 
    {
        psi = extract<feature_vector_type&>(problem.attr("get_truth_joint_feature_vector")(idx));
        pyassert(psi.size() == num_dimensions, 
                 "get_truth_joint_feature_vector() must return a vector with num_dimensions elements");
    }

    // Calls problem.separation_oracle(idx, current_solution) in Python and unpacks
    // the returned pair into loss and psi.  Raises a Python ValueError (via
    // pyassert) if the result does not have exactly two elements or if the
    // returned vector does not have num_dimensions elements.
    virtual void separation_oracle (
        const long idx,
        const matrix_type& current_solution,
        scalar_type& loss,
        feature_vector_type& psi
    ) const 
    {
        // current_solution is handed to Python by reference (no copy), so the
        // Python side must not keep it alive beyond this call.
        object res = problem.attr("separation_oracle")(idx,boost::ref(current_solution));
        pyassert(len(res) == 2, "separation_oracle() must return two objects, the loss and the psi vector");
        // let the user supply the output arguments in any order.
        if (extract<double>(res[0]).check())
        {
            loss = extract<double>(res[0]);
            psi = extract<feature_vector_type&>(res[1]);
        }
        else
        {
            psi = extract<feature_vector_type&>(res[0]);
            loss = extract<double>(res[1]);
        }
        // A wrong-sized psi would otherwise be indexed by the solver as if it had
        // num_dimensions elements.
        pyassert(psi.size() == num_dimensions, 
                 "separation_oracle() must return a psi vector with num_dimensions elements");
    }

private:

    const long num_dimensions;
    const long num_samples;
    // Held by value (a reference-counted handle) so this adapter keeps the Python
    // object alive on its own instead of relying on the caller's handle outliving it.
    object problem;
};

// ----------------------------------------------------------------------------------------

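/*
Placeholder for a sparse-feature-vector counterpart of svm_struct_dense.
Not implemented; kept commented out.

class svm_struct_sparse : public structural_svm_problem<matrix<double,0,1>, 
                                    std::vector<std::pair<unsigned long,double> > >
{
};
*/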

// ----------------------------------------------------------------------------------------

// Solves the structural SVM problem described by a Python object and returns
// the learned weight vector.
//
// Attributes read from problem:
//   C                           (required)  must be > 0
//   num_samples                 (required)  must be > 0
//   num_dimensions              (required)  must be > 0
//   epsilon                     (optional, default 0.001)  must be > 0
//   max_cache_size              (optional, default 10)     must be >= 0
//   be_verbose                  (optional, default false)
//   use_sparse_feature_vectors  (optional, default false)  only echoed in the
//                               verbose output; the dense solver is always used
//   learns_nonnegative_weights  (optional, default false)
//
// A value that violates one of the constraints above raises a Python ValueError
// (via pyassert) rather than being passed on to dlib, whose documented
// preconditions it would break.
matrix<double,0,1> solve_structural_svm_problem(
    object problem
)
{
    const double C = extract<double>(problem.attr("C"));
    const bool be_verbose = hasattr(problem,"be_verbose") && extract<bool>(problem.attr("be_verbose"));
    const bool use_sparse_feature_vectors = hasattr(problem,"use_sparse_feature_vectors") && 
                                            extract<bool>(problem.attr("use_sparse_feature_vectors"));
    const bool learns_nonnegative_weights = hasattr(problem,"learns_nonnegative_weights") && 
                                            extract<bool>(problem.attr("learns_nonnegative_weights"));

    double eps = 0.001;
    unsigned long max_cache_size = 10;
    if (hasattr(problem, "epsilon"))
        eps = extract<double>(problem.attr("epsilon"));
    if (hasattr(problem, "max_cache_size"))
    {
        // Extracted as a double so that Python floats such as 10.0 keep working.
        // The sign is checked before the cast because converting a negative (or
        // NaN) double to an unsigned integer is undefined behavior.
        const double requested_cache_size = extract<double>(problem.attr("max_cache_size"));
        pyassert(requested_cache_size >= 0, "max_cache_size must be non-negative");
        max_cache_size = static_cast<unsigned long>(requested_cache_size);
    }

    const long num_samples = extract<long>(problem.attr("num_samples"));
    const long num_dimensions = extract<long>(problem.attr("num_dimensions"));

    // The comparisons are written so that NaN inputs are rejected as well.
    pyassert(C > 0, "C must be greater than 0");
    pyassert(eps > 0, "epsilon must be greater than 0");
    pyassert(num_samples > 0, "num_samples must be greater than 0");
    pyassert(num_dimensions > 0, "num_dimensions must be greater than 0");

    if (be_verbose)
    {
        cout << "C:              " << C << endl;
        cout << "epsilon:        " << eps << endl;
        cout << "max_cache_size: " << max_cache_size << endl;
        cout << "num_samples:    " << num_samples << endl;
        cout << "num_dimensions: " << num_dimensions << endl;
        // Booleans are spelled out by hand so the process-wide cout is not left
        // in boolalpha mode after this function returns.
        cout << "use_sparse_feature_vectors: " << (use_sparse_feature_vectors ? "true" : "false") << endl;
        cout << "learns_nonnegative_weights: " << (learns_nonnegative_weights ? "true" : "false") << endl;
        cout << endl;
    }


    svm_struct_dense prob(problem, num_dimensions, num_samples);
    prob.set_c(C);
    prob.set_epsilon(eps);
    prob.set_max_cache_size(max_cache_size);
    if (be_verbose)
        prob.be_verbose();

    oca solver;
    matrix<double,0,1> w;
    // The optional third argument of the solver is given the full dimension
    // count; per dlib's oca documentation this constrains that many leading
    // elements of w (here: all of them) to be non-negative.
    if (learns_nonnegative_weights)
        solver(prob, w, prob.get_num_dimensions());
    else
        solver(prob, w);
    return w;
}

// ----------------------------------------------------------------------------------------

// Exposes solve_structural_svm_problem() to Python, in the current
// Boost.Python scope, as a function taking one keyword argument, "problem".
void bind_svm_struct()
{
    // Text shown as the Python-level docstring of the exported function.
    const char* const docstring =
        "This function solves a structural SVM problem and returns the weight vector    \n"
        "that defines the solution.  See the example program python_examples/svm_struct.py    \n"
        "for documentation about how to create a proper problem object.   ";

    def("solve_structural_svm_problem",
        solve_structural_svm_problem,
        boost::python::arg("problem"),
        docstring);
}

// ----------------------------------------------------------------------------------------