Commit da40c3ba authored by Davis King's avatar Davis King

cleaned up python interface a bit

parent e1267e0e
......@@ -3,6 +3,7 @@
#include <boost/python.hpp>
#include <boost/shared_ptr.hpp>
#include "serialize_pickle.h"
#include <boost/python/args.hpp>
#include <dlib/svm.h>
using namespace dlib;
......@@ -156,6 +157,7 @@ ranking_test _test_ranking_function2 (
void bind_decision_functions()
{
using boost::python::arg;
add_linear_df<linear_kernel<sample_type> >("_decision_function_linear");
add_linear_df<sparse_linear_kernel<sparse_vect> >("_decision_function_sparse_linear");
......@@ -172,51 +174,82 @@ void bind_decision_functions()
add_df<sparse_sigmoid_kernel<sparse_vect> >("_decision_function_sparse_sigmoid");
def("test_binary_decision_function", _test_binary_decision_function<linear_kernel<sample_type> >);
def("test_binary_decision_function", _test_binary_decision_function<sparse_linear_kernel<sparse_vect> >);
def("test_binary_decision_function", _test_binary_decision_function<radial_basis_kernel<sample_type> >);
def("test_binary_decision_function", _test_binary_decision_function<sparse_radial_basis_kernel<sparse_vect> >);
def("test_binary_decision_function", _test_binary_decision_function<polynomial_kernel<sample_type> >);
def("test_binary_decision_function", _test_binary_decision_function<sparse_polynomial_kernel<sparse_vect> >);
def("test_binary_decision_function", _test_binary_decision_function<histogram_intersection_kernel<sample_type> >);
def("test_binary_decision_function", _test_binary_decision_function<sparse_histogram_intersection_kernel<sparse_vect> >);
def("test_binary_decision_function", _test_binary_decision_function<sigmoid_kernel<sample_type> >);
def("test_binary_decision_function", _test_binary_decision_function<sparse_sigmoid_kernel<sparse_vect> >);
def("test_regression_function", _test_regression_function<linear_kernel<sample_type> >);
def("test_regression_function", _test_regression_function<sparse_linear_kernel<sparse_vect> >);
def("test_regression_function", _test_regression_function<radial_basis_kernel<sample_type> >);
def("test_regression_function", _test_regression_function<sparse_radial_basis_kernel<sparse_vect> >);
def("test_regression_function", _test_regression_function<histogram_intersection_kernel<sample_type> >);
def("test_regression_function", _test_regression_function<sparse_histogram_intersection_kernel<sparse_vect> >);
def("test_regression_function", _test_regression_function<sigmoid_kernel<sample_type> >);
def("test_regression_function", _test_regression_function<sparse_sigmoid_kernel<sparse_vect> >);
def("test_regression_function", _test_regression_function<polynomial_kernel<sample_type> >);
def("test_regression_function", _test_regression_function<sparse_polynomial_kernel<sparse_vect> >);
def("test_ranking_function", _test_ranking_function1<linear_kernel<sample_type> >);
def("test_ranking_function", _test_ranking_function1<sparse_linear_kernel<sparse_vect> >);
def("test_ranking_function", _test_ranking_function2<linear_kernel<sample_type> >);
def("test_ranking_function", _test_ranking_function2<sparse_linear_kernel<sparse_vect> >);
def("test_binary_decision_function", _test_binary_decision_function<linear_kernel<sample_type> >,
(arg("function"), arg("samples"), arg("labels")));
def("test_binary_decision_function", _test_binary_decision_function<sparse_linear_kernel<sparse_vect> >,
(arg("function"), arg("samples"), arg("labels")));
def("test_binary_decision_function", _test_binary_decision_function<radial_basis_kernel<sample_type> >,
(arg("function"), arg("samples"), arg("labels")));
def("test_binary_decision_function", _test_binary_decision_function<sparse_radial_basis_kernel<sparse_vect> >,
(arg("function"), arg("samples"), arg("labels")));
def("test_binary_decision_function", _test_binary_decision_function<polynomial_kernel<sample_type> >,
(arg("function"), arg("samples"), arg("labels")));
def("test_binary_decision_function", _test_binary_decision_function<sparse_polynomial_kernel<sparse_vect> >,
(arg("function"), arg("samples"), arg("labels")));
def("test_binary_decision_function", _test_binary_decision_function<histogram_intersection_kernel<sample_type> >,
(arg("function"), arg("samples"), arg("labels")));
def("test_binary_decision_function", _test_binary_decision_function<sparse_histogram_intersection_kernel<sparse_vect> >,
(arg("function"), arg("samples"), arg("labels")));
def("test_binary_decision_function", _test_binary_decision_function<sigmoid_kernel<sample_type> >,
(arg("function"), arg("samples"), arg("labels")));
def("test_binary_decision_function", _test_binary_decision_function<sparse_sigmoid_kernel<sparse_vect> >,
(arg("function"), arg("samples"), arg("labels")));
def("test_regression_function", _test_regression_function<linear_kernel<sample_type> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_regression_function", _test_regression_function<sparse_linear_kernel<sparse_vect> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_regression_function", _test_regression_function<radial_basis_kernel<sample_type> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_regression_function", _test_regression_function<sparse_radial_basis_kernel<sparse_vect> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_regression_function", _test_regression_function<histogram_intersection_kernel<sample_type> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_regression_function", _test_regression_function<sparse_histogram_intersection_kernel<sparse_vect> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_regression_function", _test_regression_function<sigmoid_kernel<sample_type> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_regression_function", _test_regression_function<sparse_sigmoid_kernel<sparse_vect> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_regression_function", _test_regression_function<polynomial_kernel<sample_type> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_regression_function", _test_regression_function<sparse_polynomial_kernel<sparse_vect> >,
(arg("function"), arg("samples"), arg("targets")));
def("test_ranking_function", _test_ranking_function1<linear_kernel<sample_type> >,
(arg("function"), arg("samples")));
def("test_ranking_function", _test_ranking_function1<sparse_linear_kernel<sparse_vect> >,
(arg("function"), arg("samples")));
def("test_ranking_function", _test_ranking_function2<linear_kernel<sample_type> >,
(arg("function"), arg("sample")));
def("test_ranking_function", _test_ranking_function2<sparse_linear_kernel<sparse_vect> >,
(arg("function"), arg("sample")));
class_<binary_test>("_binary_test")
.add_property("class1_accuracy", &binary_test::class1_accuracy)
.def("__str__", binary_test__str__)
.def("__repr__", binary_test__repr__)
.add_property("class2_accuracy", &binary_test::class2_accuracy);
.add_property("class1_accuracy", &binary_test::class1_accuracy,
"A value between 0 and 1, measures accuracy on the +1 class.")
.add_property("class2_accuracy", &binary_test::class2_accuracy,
"A value between 0 and 1, measures accuracy on the -1 class.");
class_<ranking_test>("_ranking_test")
.add_property("ranking_accuracy", &ranking_test::ranking_accuracy)
.def("__str__", ranking_test__str__)
.def("__repr__", ranking_test__repr__)
.add_property("mean_ap", &ranking_test::mean_ap);
.add_property("ranking_accuracy", &ranking_test::ranking_accuracy,
"A value between 0 and 1, measures the fraction of times a relevant sample was ordered before a non-relevant sample.")
.add_property("mean_ap", &ranking_test::mean_ap,
"A value between 0 and 1, measures the mean average precision of the ranking.");
class_<regression_test>("_regression_test")
.add_property("mean_squared_error", &regression_test::mean_squared_error)
.def("__str__", regression_test__str__)
.def("__repr__", regression_test__repr__)
.add_property("R_squared", &regression_test::R_squared);
.add_property("mean_squared_error", &regression_test::mean_squared_error,
"The mean squared error of a regression function on a dataset.")
.add_property("R_squared", &regression_test::R_squared,
"A value between 0 and 1, measures the squared correlation between the output of a \n"
"regression function and the target values.");
}
......
......@@ -3,6 +3,8 @@
#include <dlib/matrix.h>
#include <dlib/data_io.h>
#include <dlib/sparse_vector.h>
#include <boost/python/args.hpp>
#include "pyassert.h"
using namespace dlib;
using namespace std;
......@@ -10,27 +12,6 @@ using namespace boost::python;
typedef std::vector<std::pair<unsigned long,double> > sparse_vect;
tuple get_training_data()
{
typedef matrix<double,0,1> sample_type;
std::vector<sample_type> samples;
std::vector<double> labels;
sample_type samp(3);
for (int i = 0; i < 10; ++i)
{
samp = 1,2,3;
samples.push_back(samp);
labels.push_back(+1);
samp = -1,-2,-3;
samples.push_back(samp);
labels.push_back(-1);
}
return make_tuple(samples, labels);
}
void _make_sparse_vector (
sparse_vect& v
......@@ -61,14 +42,34 @@ void _save_libsvm_formatted_data (
const std::string& file_name,
const std::vector<sparse_vect>& samples,
const std::vector<double>& labels
) { save_libsvm_formatted_data(file_name, samples, labels); }
)
{
pyassert(samples.size() == labels.size(), "Invalid inputs");
save_libsvm_formatted_data(file_name, samples, labels);
}
void bind_other()
{
def("get_training_data",get_training_data);
def("make_sparse_vector", _make_sparse_vector , "This function modifies its argument so that it is a properly sorted sparse vector.");
def("make_sparse_vector", _make_sparse_vector2 , "This function modifies a sparse_vectors object so that all elements it contains are properly sorted sparse vectors.");
def("load_libsvm_formatted_data",_load_libsvm_formatted_data);
def("save_libsvm_formatted_data",_save_libsvm_formatted_data);
using boost::python::arg;
def("make_sparse_vector", _make_sparse_vector ,
"This function modifies its argument so that it is a properly sorted sparse vector.");
def("make_sparse_vector", _make_sparse_vector2 ,
"This function modifies a sparse_vectors object so that all elements it contains are properly sorted sparse vectors.");
def("load_libsvm_formatted_data",_load_libsvm_formatted_data, (arg("file_name")),
"ensures \n\
- Attempts to read a file of the given name that should contain libsvm \n\
formatted data. The data is returned as a tuple where the first tuple \n\
element is an array of sparse vectors and the second element is an array of \n\
labels. "
);
def("save_libsvm_formatted_data",_save_libsvm_formatted_data, (arg("file_name"), arg("samples"), arg("labels")),
"requires \n\
- len(samples) == len(labels) \n\
ensures \n\
- saves the data to the given file in libsvm format "
);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment