Commit 14cbb804 authored by ksemb's avatar ksemb Committed by Davis E. King

Add Python rvm_trainer and init functions (#1100)

parent b2a0bae4
...@@ -72,56 +72,51 @@ template <typename trainer_type> ...@@ -72,56 +72,51 @@ template <typename trainer_type>
double get_c_class2 ( const trainer_type& trainer) { return trainer.get_c_class2(); } double get_c_class2 ( const trainer_type& trainer) { return trainer.get_c_class2(); }
template <typename trainer_type> template <typename trainer_type>
py::class_<trainer_type> setup_trainer ( py::class_<trainer_type> setup_trainer_eps (
py::module& m, py::module& m,
const std::string& name const std::string& name
) )
{ {
return py::class_<trainer_type>(m, name.c_str()) return py::class_<trainer_type>(m, name.c_str())
.def("train", train<trainer_type>) .def("train", train<trainer_type>)
.def("set_c", set_c<trainer_type>)
.def_property("c_class1", get_c_class1<trainer_type>, set_c_class1<trainer_type>)
.def_property("c_class2", get_c_class2<trainer_type>, set_c_class2<trainer_type>)
.def_property("epsilon", get_epsilon<trainer_type>, set_epsilon<trainer_type>); .def_property("epsilon", get_epsilon<trainer_type>, set_epsilon<trainer_type>);
} }
template <typename trainer_type> template <typename trainer_type>
py::class_<trainer_type> setup_trainer2 ( py::class_<trainer_type> setup_trainer_eps_c (
py::module& m, py::module& m,
const std::string& name const std::string& name
) )
{ {
return setup_trainer<trainer_type>(m, name) return setup_trainer_eps<trainer_type>(m, name)
.def_property("cache_size", get_cache_size<trainer_type>, set_cache_size<trainer_type>); .def("set_c", set_c<trainer_type>)
} .def_property("c_class1", get_c_class1<trainer_type>, set_c_class1<trainer_type>)
.def_property("c_class2", get_c_class2<trainer_type>, set_c_class2<trainer_type>);
void set_gamma (
svm_c_trainer<radial_basis_kernel<sample_type> >& trainer,
double gamma
)
{
pyassert(gamma > 0, "gamma must be > 0");
trainer.set_kernel(radial_basis_kernel<sample_type>(gamma));
} }
double get_gamma ( template <typename trainer_type>
const svm_c_trainer<radial_basis_kernel<sample_type> >& trainer py::class_<trainer_type> setup_trainer_eps_c_cache (
py::module& m,
const std::string& name
) )
{ {
return trainer.get_kernel().gamma; return setup_trainer_eps_c<trainer_type>(m, name)
.def_property("cache_size", get_cache_size<trainer_type>, set_cache_size<trainer_type>);
} }
void set_gamma_sparse ( template <typename trainer_type>
svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> >& trainer, void set_gamma (
trainer_type& trainer,
double gamma double gamma
) )
{ {
pyassert(gamma > 0, "gamma must be > 0"); pyassert(gamma > 0, "gamma must be > 0");
trainer.set_kernel(sparse_radial_basis_kernel<sparse_vect>(gamma)); trainer.set_kernel(typename trainer_type::kernel_type(gamma));
} }
double get_gamma_sparse ( template <typename trainer_type>
const svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> >& trainer double get_gamma (
const trainer_type& trainer
) )
{ {
return trainer.get_kernel().gamma; return trainer.get_kernel().gamma;
...@@ -166,10 +161,13 @@ const binary_test _cross_validate_trainer_t ( ...@@ -166,10 +161,13 @@ const binary_test _cross_validate_trainer_t (
void bind_svm_c_trainer(py::module& m) void bind_svm_c_trainer(py::module& m)
{ {
namespace py = pybind11; namespace py = pybind11;
// svm_c
{ {
typedef svm_c_trainer<radial_basis_kernel<sample_type> > T; typedef svm_c_trainer<radial_basis_kernel<sample_type> > T;
setup_trainer2<T>(m, "svm_c_trainer_radial_basis") setup_trainer_eps_c_cache<T>(m, "svm_c_trainer_radial_basis")
.def_property("gamma", get_gamma, set_gamma); .def(py::init())
.def_property("gamma", get_gamma<T>, set_gamma<T>);
m.def("cross_validate_trainer", _cross_validate_trainer<T>, m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds")); py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
...@@ -178,8 +176,9 @@ void bind_svm_c_trainer(py::module& m) ...@@ -178,8 +176,9 @@ void bind_svm_c_trainer(py::module& m)
{ {
typedef svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> > T; typedef svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> > T;
setup_trainer2<T>(m, "svm_c_trainer_sparse_radial_basis") setup_trainer_eps_c_cache<T>(m, "svm_c_trainer_sparse_radial_basis")
.def_property("gamma", get_gamma_sparse, set_gamma_sparse); .def(py::init())
.def_property("gamma", get_gamma<T>, set_gamma<T>);
m.def("cross_validate_trainer", _cross_validate_trainer<T>, m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds")); py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
...@@ -188,7 +187,8 @@ void bind_svm_c_trainer(py::module& m) ...@@ -188,7 +187,8 @@ void bind_svm_c_trainer(py::module& m)
{ {
typedef svm_c_trainer<histogram_intersection_kernel<sample_type> > T; typedef svm_c_trainer<histogram_intersection_kernel<sample_type> > T;
setup_trainer2<T>(m, "svm_c_trainer_histogram_intersection"); setup_trainer_eps_c_cache<T>(m, "svm_c_trainer_histogram_intersection")
.def(py::init());
m.def("cross_validate_trainer", _cross_validate_trainer<T>, m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds")); py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
...@@ -197,16 +197,18 @@ void bind_svm_c_trainer(py::module& m) ...@@ -197,16 +197,18 @@ void bind_svm_c_trainer(py::module& m)
{ {
typedef svm_c_trainer<sparse_histogram_intersection_kernel<sparse_vect> > T; typedef svm_c_trainer<sparse_histogram_intersection_kernel<sparse_vect> > T;
setup_trainer2<T>(m, "svm_c_trainer_sparse_histogram_intersection"); setup_trainer_eps_c_cache<T>(m, "svm_c_trainer_sparse_histogram_intersection")
.def(py::init());
m.def("cross_validate_trainer", _cross_validate_trainer<T>, m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds")); py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads")); py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
} }
// svm_c_linear
{ {
typedef svm_c_linear_trainer<linear_kernel<sample_type> > T; typedef svm_c_linear_trainer<linear_kernel<sample_type> > T;
setup_trainer<T>(m, "svm_c_trainer_linear") setup_trainer_eps_c<T>(m, "svm_c_trainer_linear")
.def(py::init()) .def(py::init())
.def_property("max_iterations", &T::get_max_iterations, &T::set_max_iterations) .def_property("max_iterations", &T::get_max_iterations, &T::set_max_iterations)
.def_property("force_last_weight_to_1", &T::forces_last_weight_to_1, &T::force_last_weight_to_1) .def_property("force_last_weight_to_1", &T::forces_last_weight_to_1, &T::force_last_weight_to_1)
...@@ -224,7 +226,8 @@ void bind_svm_c_trainer(py::module& m) ...@@ -224,7 +226,8 @@ void bind_svm_c_trainer(py::module& m)
{ {
typedef svm_c_linear_trainer<sparse_linear_kernel<sparse_vect> > T; typedef svm_c_linear_trainer<sparse_linear_kernel<sparse_vect> > T;
setup_trainer<T>(m, "svm_c_trainer_sparse_linear") setup_trainer_eps_c<T>(m, "svm_c_trainer_sparse_linear")
.def(py::init())
.def_property("max_iterations", &T::get_max_iterations, &T::set_max_iterations) .def_property("max_iterations", &T::get_max_iterations, &T::set_max_iterations)
.def_property("force_last_weight_to_1", &T::forces_last_weight_to_1, &T::force_last_weight_to_1) .def_property("force_last_weight_to_1", &T::forces_last_weight_to_1, &T::force_last_weight_to_1)
.def_property("learns_nonnegative_weights", &T::learns_nonnegative_weights, &T::set_learns_nonnegative_weights) .def_property("learns_nonnegative_weights", &T::learns_nonnegative_weights, &T::set_learns_nonnegative_weights)
...@@ -238,6 +241,70 @@ void bind_svm_c_trainer(py::module& m) ...@@ -238,6 +241,70 @@ void bind_svm_c_trainer(py::module& m)
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>, m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads")); py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
} }
// rvm
{
typedef rvm_trainer<radial_basis_kernel<sample_type> > T;
setup_trainer_eps<T>(m, "rvm_trainer_radial_basis")
.def(py::init())
.def_property("gamma", get_gamma<T>, set_gamma<T>);
m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
}
{
typedef rvm_trainer<sparse_radial_basis_kernel<sparse_vect> > T;
setup_trainer_eps<T>(m, "rvm_trainer_sparse_radial_basis")
.def(py::init())
.def_property("gamma", get_gamma<T>, set_gamma<T>);
m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
}
{
typedef rvm_trainer<histogram_intersection_kernel<sample_type> > T;
setup_trainer_eps<T>(m, "rvm_trainer_histogram_intersection")
.def(py::init());
m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
}
{
typedef rvm_trainer<sparse_histogram_intersection_kernel<sparse_vect> > T;
setup_trainer_eps<T>(m, "rvm_trainer_sparse_histogram_intersection")
.def(py::init());
m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
}
// rvm linear
{
typedef rvm_trainer<linear_kernel<sample_type> > T;
setup_trainer_eps<T>(m, "rvm_trainer_linear")
.def(py::init());
m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
}
{
typedef rvm_trainer<sparse_linear_kernel<sparse_vect> > T;
setup_trainer_eps<T>(m, "rvm_trainer_sparse_linear")
.def(py::init());
m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
}
} }
from __future__ import division
import pytest
from random import Random
from dlib import (vectors, vector, sparse_vectors, sparse_vector, pair, array,
cross_validate_trainer,
svm_c_trainer_radial_basis,
svm_c_trainer_sparse_radial_basis,
svm_c_trainer_histogram_intersection,
svm_c_trainer_sparse_histogram_intersection,
svm_c_trainer_linear,
svm_c_trainer_sparse_linear,
rvm_trainer_radial_basis,
rvm_trainer_sparse_radial_basis,
rvm_trainer_histogram_intersection,
rvm_trainer_sparse_histogram_intersection,
rvm_trainer_linear,
rvm_trainer_sparse_linear)
@pytest.fixture
def training_data():
r = Random(0)
predictors = vectors()
sparse_predictors = sparse_vectors()
response = array()
for i in range(30):
for c in [-1, 1]:
response.append(c)
values = [r.random() + c * 0.5 for _ in range(3)]
predictors.append(vector(values))
sp = sparse_vector()
for i, v in enumerate(values):
sp.append(pair(i, v))
sparse_predictors.append(sp)
return predictors, sparse_predictors, response
@pytest.mark.parametrize('trainer, class1_accuracy, class2_accuracy', [
(svm_c_trainer_radial_basis, 1.0, 1.0),
(svm_c_trainer_sparse_radial_basis, 1.0, 1.0),
(svm_c_trainer_histogram_intersection, 1.0, 1.0),
(svm_c_trainer_sparse_histogram_intersection, 1.0, 1.0),
(svm_c_trainer_linear, 1.0, 23 / 30),
(svm_c_trainer_sparse_linear, 1.0, 23 / 30),
(rvm_trainer_radial_basis, 1.0, 1.0),
(rvm_trainer_sparse_radial_basis, 1.0, 1.0),
(rvm_trainer_histogram_intersection, 1.0, 1.0),
(rvm_trainer_sparse_histogram_intersection, 1.0, 1.0),
(rvm_trainer_linear, 1.0, 0.6),
(rvm_trainer_sparse_linear, 1.0, 0.6)
])
def test_trainers(training_data, trainer, class1_accuracy, class2_accuracy):
predictors, sparse_predictors, response = training_data
if 'sparse' in trainer.__name__:
predictors = sparse_predictors
cv = cross_validate_trainer(trainer(), predictors, response, folds=10)
assert cv.class1_accuracy == pytest.approx(class1_accuracy)
assert cv.class2_accuracy == pytest.approx(class2_accuracy)
decision_function = trainer().train(predictors, response)
assert decision_function(predictors[2]) < 0
assert decision_function(predictors[3]) > 0
if 'linear' in trainer.__name__:
assert len(decision_function.weights) == 3
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment