#include #include #include #include "serialize_pickle.h" #include #include "pyassert.h" #include using namespace dlib; using namespace std; using namespace boost::python; typedef matrix sample_type; typedef std::vector > sparse_vect; // ---------------------------------------------------------------------------------------- namespace dlib { template bool operator== ( const ranking_pair& , const ranking_pair& ) { pyassert(false, "It is illegal to compare ranking pair objects for equality."); return false; } } template void resize(T& v, unsigned long n) { v.resize(n); } // ---------------------------------------------------------------------------------------- template typename trainer_type::trained_function_type train1 ( const trainer_type& trainer, const ranking_pair& sample ) { typedef ranking_pair st; pyassert(is_ranking_problem(std::vector(1, sample)), "Invalid inputs"); return trainer.train(sample); } template typename trainer_type::trained_function_type train2 ( const trainer_type& trainer, const std::vector >& samples ) { pyassert(is_ranking_problem(samples), "Invalid inputs"); return trainer.train(samples); } template void set_epsilon ( trainer_type& trainer, double eps) { pyassert(eps > 0, "epsilon must be > 0"); trainer.set_epsilon(eps); } template double get_epsilon ( const trainer_type& trainer) { return trainer.get_epsilon(); } template void set_c ( trainer_type& trainer, double C) { pyassert(C > 0, "C must be > 0"); trainer.set_c(C); } template double get_c (const trainer_type& trainer) { return trainer.get_c(); } template void add_ranker ( const char* name ) { class_(name) .add_property("epsilon", get_epsilon, set_epsilon) .add_property("c", get_c, set_c) .add_property("max_iterations", &trainer::get_max_iterations, &trainer::set_max_iterations) .add_property("force_last_weight_to_1", &trainer::forces_last_weight_to_1, &trainer::force_last_weight_to_1) .add_property("learns_nonnegative_weights", &trainer::learns_nonnegative_weights, &trainer::set_learns_nonnegative_weights) .def("train", train1) .def("train", train2) .def("be_verbose", &trainer::be_verbose) .def("be_quiet", &trainer::be_quiet); } // ---------------------------------------------------------------------------------------- void bind_svm_rank_trainer() { class_ >("ranking_pair") .add_property("relevant", &ranking_pair::relevant) .add_property("nonrelevant", &ranking_pair::nonrelevant) .def_pickle(serialize_pickle >()); class_ >("sparse_ranking_pair") .add_property("relevant", &ranking_pair::relevant) .add_property("nonrelevant", &ranking_pair::nonrelevant) .def_pickle(serialize_pickle >()); typedef std::vector > ranking_pairs; class_("ranking_pairs") .def(vector_indexing_suite()) .def("clear", &ranking_pairs::clear) .def("resize", resize) .def_pickle(serialize_pickle()); typedef std::vector > sparse_ranking_pairs; class_("sparse_ranking_pairs") .def(vector_indexing_suite()) .def("clear", &sparse_ranking_pairs::clear) .def("resize", resize) .def_pickle(serialize_pickle()); add_ranker > >("svm_rank_trainer"); add_ranker > >("svm_rank_trainer_sparse"); }