1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#include <boost/python.hpp>
#include <boost/shared_ptr.hpp>
#include <dlib/matrix.h>
#include "serialize_pickle.h"
#include <dlib/svm.h>
#include "pyassert.h"
#include <boost/python/suite/indexing/vector_indexing_suite.hpp>
using namespace dlib;
using namespace std;
using namespace boost::python;
typedef matrix<double,0,1> sample_type;
typedef std::vector<std::pair<unsigned long,double> > sparse_vect;
// ----------------------------------------------------------------------------------------
namespace dlib
{
    // Equality between ranking_pair objects is intentionally unsupported.
    // This overload lives in namespace dlib (next to ranking_pair) so that
    // templates requiring operator== still compile — presumably the
    // vector_indexing_suite bindings for ranking_pairs, which use == for
    // __contains__/index (NOTE(review): confirm). Calling it always fails the
    // pyassert check instead of comparing anything.
    template <typename T>
    bool operator== (const ranking_pair<T>&, const ranking_pair<T>&)
    {
        pyassert(false, "It is illegal to compare ranking pair objects for equality.");

        // Only reached if pyassert does not abort the call; keeps every
        // control path returning a value.
        return false;
    }
}
// Free-function wrapper around a container's resize() member so it can be
// bound to Python with .def("resize", resize<Container>).
template <typename Container>
void resize(Container& container, unsigned long new_size)
{
    container.resize(new_size);
}
// ----------------------------------------------------------------------------------------
template <typename trainer_type>
typename trainer_type::trained_function_type train1 (
const trainer_type& trainer,
const ranking_pair<typename trainer_type::sample_type>& sample
)
{
typedef ranking_pair<typename trainer_type::sample_type> st;
pyassert(is_ranking_problem(std::vector<st>(1, sample)), "Invalid inputs");
return trainer.train(sample);
}
template <typename trainer_type>
typename trainer_type::trained_function_type train2 (
const trainer_type& trainer,
const std::vector<ranking_pair<typename trainer_type::sample_type> >& samples
)
{
pyassert(is_ranking_problem(samples), "Invalid inputs");
return trainer.train(samples);
}
// Python setter for the `epsilon` property.
// The value must be strictly positive; NaN also fails the check because
// NaN > 0 is false. The value is forwarded to trainer.set_epsilon() after
// the pyassert check.
template <typename trainer_type>
void set_epsilon (
    trainer_type& trainer,
    double eps
)
{
    const bool eps_is_positive = eps > 0;
    pyassert(eps_is_positive, "epsilon must be > 0");
    trainer.set_epsilon(eps);
}
// Python getter for the `epsilon` property: returns trainer.get_epsilon()
// as a double.
template <typename trainer_type>
double get_epsilon (
    const trainer_type& trainer
)
{
    const double current_eps = trainer.get_epsilon();
    return current_eps;
}
// Python setter for the `c` property (the SVM regularization parameter).
// The value must be strictly positive; NaN also fails the check because
// NaN > 0 is false. The value is forwarded to trainer.set_c() after the
// pyassert check.
template <typename trainer_type>
void set_c (
    trainer_type& trainer,
    double C
)
{
    const bool c_is_positive = C > 0;
    pyassert(c_is_positive, "C must be > 0");
    trainer.set_c(C);
}
// Python getter for the `c` property: returns trainer.get_c() as a double.
template <typename trainer_type>
double get_c (
    const trainer_type& trainer
)
{
    const double current_c = trainer.get_c();
    return current_c;
}
template <typename trainer>
void add_ranker (
const char* name
)
{
class_<trainer>(name)
.add_property("epsilon", get_epsilon<trainer>, set_epsilon<trainer>)
.add_property("c", get_c<trainer>, set_c<trainer>)
.add_property("max_iterations", &trainer::get_max_iterations, &trainer::set_max_iterations)
.add_property("force_last_weight_to_1", &trainer::forces_last_weight_to_1, &trainer::force_last_weight_to_1)
.add_property("learns_nonnegative_weights", &trainer::learns_nonnegative_weights, &trainer::set_learns_nonnegative_weights)
.def("train", train1<trainer>)
.def("train", train2<trainer>)
.def("be_verbose", &trainer::be_verbose)
.def("be_quiet", &trainer::be_quiet);
}
// ----------------------------------------------------------------------------------------
void bind_svm_rank_trainer()
{
class_<ranking_pair<sample_type> >("ranking_pair")
.add_property("relevant", &ranking_pair<sample_type>::relevant)
.add_property("nonrelevant", &ranking_pair<sample_type>::nonrelevant)
.def_pickle(serialize_pickle<ranking_pair<sample_type> >());
class_<ranking_pair<sparse_vect> >("sparse_ranking_pair")
.add_property("relevant", &ranking_pair<sparse_vect>::relevant)
.add_property("nonrelevant", &ranking_pair<sparse_vect>::nonrelevant)
.def_pickle(serialize_pickle<ranking_pair<sparse_vect> >());
typedef std::vector<ranking_pair<sample_type> > ranking_pairs;
class_<ranking_pairs>("ranking_pairs")
.def(vector_indexing_suite<ranking_pairs>())
.def("clear", &ranking_pairs::clear)
.def("resize", resize<ranking_pairs>)
.def_pickle(serialize_pickle<ranking_pairs>());
typedef std::vector<ranking_pair<sparse_vect> > sparse_ranking_pairs;
class_<sparse_ranking_pairs>("sparse_ranking_pairs")
.def(vector_indexing_suite<sparse_ranking_pairs>())
.def("clear", &sparse_ranking_pairs::clear)
.def("resize", resize<sparse_ranking_pairs>)
.def_pickle(serialize_pickle<sparse_ranking_pairs>());
add_ranker<svm_rank_trainer<linear_kernel<sample_type> > >("svm_rank_trainer");
add_ranker<svm_rank_trainer<sparse_linear_kernel<sparse_vect> > >("svm_rank_trainer_sparse");
}