Commit da40c3ba authored by Davis King's avatar Davis King

cleaned up python interface a bit

parent e1267e0e
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <boost/python.hpp> #include <boost/python.hpp>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include "serialize_pickle.h" #include "serialize_pickle.h"
#include <boost/python/args.hpp>
#include <dlib/svm.h> #include <dlib/svm.h>
using namespace dlib; using namespace dlib;
...@@ -156,6 +157,7 @@ ranking_test _test_ranking_function2 ( ...@@ -156,6 +157,7 @@ ranking_test _test_ranking_function2 (
void bind_decision_functions() void bind_decision_functions()
{ {
using boost::python::arg;
add_linear_df<linear_kernel<sample_type> >("_decision_function_linear"); add_linear_df<linear_kernel<sample_type> >("_decision_function_linear");
add_linear_df<sparse_linear_kernel<sparse_vect> >("_decision_function_sparse_linear"); add_linear_df<sparse_linear_kernel<sparse_vect> >("_decision_function_sparse_linear");
...@@ -172,51 +174,82 @@ void bind_decision_functions() ...@@ -172,51 +174,82 @@ void bind_decision_functions()
add_df<sparse_sigmoid_kernel<sparse_vect> >("_decision_function_sparse_sigmoid"); add_df<sparse_sigmoid_kernel<sparse_vect> >("_decision_function_sparse_sigmoid");
def("test_binary_decision_function", _test_binary_decision_function<linear_kernel<sample_type> >); def("test_binary_decision_function", _test_binary_decision_function<linear_kernel<sample_type> >,
def("test_binary_decision_function", _test_binary_decision_function<sparse_linear_kernel<sparse_vect> >); (arg("function"), arg("samples"), arg("labels")));
def("test_binary_decision_function", _test_binary_decision_function<radial_basis_kernel<sample_type> >); def("test_binary_decision_function", _test_binary_decision_function<sparse_linear_kernel<sparse_vect> >,
def("test_binary_decision_function", _test_binary_decision_function<sparse_radial_basis_kernel<sparse_vect> >); (arg("function"), arg("samples"), arg("labels")));
def("test_binary_decision_function", _test_binary_decision_function<polynomial_kernel<sample_type> >); def("test_binary_decision_function", _test_binary_decision_function<radial_basis_kernel<sample_type> >,
def("test_binary_decision_function", _test_binary_decision_function<sparse_polynomial_kernel<sparse_vect> >); (arg("function"), arg("samples"), arg("labels")));
def("test_binary_decision_function", _test_binary_decision_function<histogram_intersection_kernel<sample_type> >); def("test_binary_decision_function", _test_binary_decision_function<sparse_radial_basis_kernel<sparse_vect> >,
def("test_binary_decision_function", _test_binary_decision_function<sparse_histogram_intersection_kernel<sparse_vect> >); (arg("function"), arg("samples"), arg("labels")));
def("test_binary_decision_function", _test_binary_decision_function<sigmoid_kernel<sample_type> >); def("test_binary_decision_function", _test_binary_decision_function<polynomial_kernel<sample_type> >,
def("test_binary_decision_function", _test_binary_decision_function<sparse_sigmoid_kernel<sparse_vect> >); (arg("function"), arg("samples"), arg("labels")));
def("test_binary_decision_function", _test_binary_decision_function<sparse_polynomial_kernel<sparse_vect> >,
def("test_regression_function", _test_regression_function<linear_kernel<sample_type> >); (arg("function"), arg("samples"), arg("labels")));
def("test_regression_function", _test_regression_function<sparse_linear_kernel<sparse_vect> >); def("test_binary_decision_function", _test_binary_decision_function<histogram_intersection_kernel<sample_type> >,
def("test_regression_function", _test_regression_function<radial_basis_kernel<sample_type> >); (arg("function"), arg("samples"), arg("labels")));
def("test_regression_function", _test_regression_function<sparse_radial_basis_kernel<sparse_vect> >); def("test_binary_decision_function", _test_binary_decision_function<sparse_histogram_intersection_kernel<sparse_vect> >,
def("test_regression_function", _test_regression_function<histogram_intersection_kernel<sample_type> >); (arg("function"), arg("samples"), arg("labels")));
def("test_regression_function", _test_regression_function<sparse_histogram_intersection_kernel<sparse_vect> >); def("test_binary_decision_function", _test_binary_decision_function<sigmoid_kernel<sample_type> >,
def("test_regression_function", _test_regression_function<sigmoid_kernel<sample_type> >); (arg("function"), arg("samples"), arg("labels")));
def("test_regression_function", _test_regression_function<sparse_sigmoid_kernel<sparse_vect> >); def("test_binary_decision_function", _test_binary_decision_function<sparse_sigmoid_kernel<sparse_vect> >,
def("test_regression_function", _test_regression_function<polynomial_kernel<sample_type> >); (arg("function"), arg("samples"), arg("labels")));
def("test_regression_function", _test_regression_function<sparse_polynomial_kernel<sparse_vect> >);
def("test_regression_function", _test_regression_function<linear_kernel<sample_type> >,
def("test_ranking_function", _test_ranking_function1<linear_kernel<sample_type> >); (arg("function"), arg("samples"), arg("targets")));
def("test_ranking_function", _test_ranking_function1<sparse_linear_kernel<sparse_vect> >); def("test_regression_function", _test_regression_function<sparse_linear_kernel<sparse_vect> >,
def("test_ranking_function", _test_ranking_function2<linear_kernel<sample_type> >); (arg("function"), arg("samples"), arg("targets")));
def("test_ranking_function", _test_ranking_function2<sparse_linear_kernel<sparse_vect> >); def("test_regression_function", _test_regression_function<radial_basis_kernel<sample_type> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_regression_function", _test_regression_function<sparse_radial_basis_kernel<sparse_vect> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_regression_function", _test_regression_function<histogram_intersection_kernel<sample_type> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_regression_function", _test_regression_function<sparse_histogram_intersection_kernel<sparse_vect> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_regression_function", _test_regression_function<sigmoid_kernel<sample_type> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_regression_function", _test_regression_function<sparse_sigmoid_kernel<sparse_vect> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_regression_function", _test_regression_function<polynomial_kernel<sample_type> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_regression_function", _test_regression_function<sparse_polynomial_kernel<sparse_vect> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_ranking_function", _test_ranking_function1<linear_kernel<sample_type> >,
(arg("function"), arg("samples")));
def("test_ranking_function", _test_ranking_function1<sparse_linear_kernel<sparse_vect> >,
(arg("function"), arg("samples")));
def("test_ranking_function", _test_ranking_function2<linear_kernel<sample_type> >,
(arg("function"), arg("sample")));
def("test_ranking_function", _test_ranking_function2<sparse_linear_kernel<sparse_vect> >,
(arg("function"), arg("sample")));
class_<binary_test>("_binary_test") class_<binary_test>("_binary_test")
.add_property("class1_accuracy", &binary_test::class1_accuracy)
.def("__str__", binary_test__str__) .def("__str__", binary_test__str__)
.def("__repr__", binary_test__repr__) .def("__repr__", binary_test__repr__)
.add_property("class2_accuracy", &binary_test::class2_accuracy); .add_property("class1_accuracy", &binary_test::class1_accuracy,
"A value between 0 and 1, measures accuracy on the +1 class.")
.add_property("class2_accuracy", &binary_test::class2_accuracy,
"A value between 0 and 1, measures accuracy on the -1 class.");
class_<ranking_test>("_ranking_test") class_<ranking_test>("_ranking_test")
.add_property("ranking_accuracy", &ranking_test::ranking_accuracy)
.def("__str__", ranking_test__str__) .def("__str__", ranking_test__str__)
.def("__repr__", ranking_test__repr__) .def("__repr__", ranking_test__repr__)
.add_property("mean_ap", &ranking_test::mean_ap); .add_property("ranking_accuracy", &ranking_test::ranking_accuracy,
"A value between 0 and 1, measures the fraction of times a relevant sample was ordered before a non-relevant sample.")
.add_property("mean_ap", &ranking_test::mean_ap,
"A value between 0 and 1, measures the mean average precision of the ranking.");
class_<regression_test>("_regression_test") class_<regression_test>("_regression_test")
.add_property("mean_squared_error", &regression_test::mean_squared_error)
.def("__str__", regression_test__str__) .def("__str__", regression_test__str__)
.def("__repr__", regression_test__repr__) .def("__repr__", regression_test__repr__)
.add_property("R_squared", &regression_test::R_squared); .add_property("mean_squared_error", &regression_test::mean_squared_error,
"The mean squared error of a regression function on a dataset.")
.add_property("R_squared", &regression_test::R_squared,
"A value between 0 and 1, measures the squared correlation between the output of a \n"
"regression function and the target values.");
} }
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
#include <dlib/matrix.h> #include <dlib/matrix.h>
#include <dlib/data_io.h> #include <dlib/data_io.h>
#include <dlib/sparse_vector.h> #include <dlib/sparse_vector.h>
#include <boost/python/args.hpp>
#include "pyassert.h"
using namespace dlib; using namespace dlib;
using namespace std; using namespace std;
...@@ -10,27 +12,6 @@ using namespace boost::python; ...@@ -10,27 +12,6 @@ using namespace boost::python;
typedef std::vector<std::pair<unsigned long,double> > sparse_vect; typedef std::vector<std::pair<unsigned long,double> > sparse_vect;
tuple get_training_data()
{
typedef matrix<double,0,1> sample_type;
std::vector<sample_type> samples;
std::vector<double> labels;
sample_type samp(3);
for (int i = 0; i < 10; ++i)
{
samp = 1,2,3;
samples.push_back(samp);
labels.push_back(+1);
samp = -1,-2,-3;
samples.push_back(samp);
labels.push_back(-1);
}
return make_tuple(samples, labels);
}
void _make_sparse_vector ( void _make_sparse_vector (
sparse_vect& v sparse_vect& v
...@@ -61,14 +42,34 @@ void _save_libsvm_formatted_data ( ...@@ -61,14 +42,34 @@ void _save_libsvm_formatted_data (
const std::string& file_name, const std::string& file_name,
const std::vector<sparse_vect>& samples, const std::vector<sparse_vect>& samples,
const std::vector<double>& labels const std::vector<double>& labels
) { save_libsvm_formatted_data(file_name, samples, labels); } )
{
pyassert(samples.size() == labels.size(), "Invalid inputs");
save_libsvm_formatted_data(file_name, samples, labels);
}
void bind_other() void bind_other()
{ {
def("get_training_data",get_training_data); using boost::python::arg;
def("make_sparse_vector", _make_sparse_vector , "This function modifies its argument so that it is a properly sorted sparse vector.");
def("make_sparse_vector", _make_sparse_vector2 , "This function modifies a sparse_vectors object so that all elements it contains are properly sorted sparse vectors."); def("make_sparse_vector", _make_sparse_vector ,
def("load_libsvm_formatted_data",_load_libsvm_formatted_data); "This function modifies its argument so that it is a properly sorted sparse vector.");
def("save_libsvm_formatted_data",_save_libsvm_formatted_data); def("make_sparse_vector", _make_sparse_vector2 ,
"This function modifies a sparse_vectors object so that all elements it contains are properly sorted sparse vectors.");
def("load_libsvm_formatted_data",_load_libsvm_formatted_data, (arg("file_name")),
"ensures \n\
- Attempts to read a file of the given name that should contain libsvm \n\
formatted data. The data is returned as a tuple where the first tuple \n\
element is an array of sparse vectors and the second element is an array of \n\
labels. "
);
def("save_libsvm_formatted_data",_save_libsvm_formatted_data, (arg("file_name"), arg("samples"), arg("labels")),
"requires \n\
- len(samples) == len(labels) \n\
ensures \n\
- saves the data to the given file in libsvm format "
);
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment