Commit 14cbb804 authored by ksemb's avatar ksemb Committed by Davis E. King

Add Python rvm_trainer and init functions (#1100)

parent b2a0bae4
......@@ -72,56 +72,51 @@ template <typename trainer_type>
double get_c_class2 ( const trainer_type& trainer) { return trainer.get_c_class2(); }
template <typename trainer_type>
py::class_<trainer_type> setup_trainer (
py::class_<trainer_type> setup_trainer_eps (
py::module& m,
const std::string& name
)
{
return py::class_<trainer_type>(m, name.c_str())
.def("train", train<trainer_type>)
.def("set_c", set_c<trainer_type>)
.def_property("c_class1", get_c_class1<trainer_type>, set_c_class1<trainer_type>)
.def_property("c_class2", get_c_class2<trainer_type>, set_c_class2<trainer_type>)
.def_property("epsilon", get_epsilon<trainer_type>, set_epsilon<trainer_type>);
}
template <typename trainer_type>
py::class_<trainer_type> setup_trainer2 (
py::class_<trainer_type> setup_trainer_eps_c (
py::module& m,
const std::string& name
)
{
return setup_trainer<trainer_type>(m, name)
.def_property("cache_size", get_cache_size<trainer_type>, set_cache_size<trainer_type>);
}
void set_gamma (
svm_c_trainer<radial_basis_kernel<sample_type> >& trainer,
double gamma
)
{
pyassert(gamma > 0, "gamma must be > 0");
trainer.set_kernel(radial_basis_kernel<sample_type>(gamma));
return setup_trainer_eps<trainer_type>(m, name)
.def("set_c", set_c<trainer_type>)
.def_property("c_class1", get_c_class1<trainer_type>, set_c_class1<trainer_type>)
.def_property("c_class2", get_c_class2<trainer_type>, set_c_class2<trainer_type>);
}
double get_gamma (
const svm_c_trainer<radial_basis_kernel<sample_type> >& trainer
template <typename trainer_type>
py::class_<trainer_type> setup_trainer_eps_c_cache (
py::module& m,
const std::string& name
)
{
return trainer.get_kernel().gamma;
return setup_trainer_eps_c<trainer_type>(m, name)
.def_property("cache_size", get_cache_size<trainer_type>, set_cache_size<trainer_type>);
}
void set_gamma_sparse (
svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> >& trainer,
template <typename trainer_type>
void set_gamma (
trainer_type& trainer,
double gamma
)
{
pyassert(gamma > 0, "gamma must be > 0");
trainer.set_kernel(sparse_radial_basis_kernel<sparse_vect>(gamma));
trainer.set_kernel(typename trainer_type::kernel_type(gamma));
}
double get_gamma_sparse (
const svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> >& trainer
template <typename trainer_type>
double get_gamma (
const trainer_type& trainer
)
{
return trainer.get_kernel().gamma;
......@@ -166,10 +161,13 @@ const binary_test _cross_validate_trainer_t (
void bind_svm_c_trainer(py::module& m)
{
namespace py = pybind11;
// svm_c
{
typedef svm_c_trainer<radial_basis_kernel<sample_type> > T;
setup_trainer2<T>(m, "svm_c_trainer_radial_basis")
.def_property("gamma", get_gamma, set_gamma);
setup_trainer_eps_c_cache<T>(m, "svm_c_trainer_radial_basis")
.def(py::init())
.def_property("gamma", get_gamma<T>, set_gamma<T>);
m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
......@@ -178,8 +176,9 @@ void bind_svm_c_trainer(py::module& m)
{
typedef svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> > T;
setup_trainer2<T>(m, "svm_c_trainer_sparse_radial_basis")
.def_property("gamma", get_gamma_sparse, set_gamma_sparse);
setup_trainer_eps_c_cache<T>(m, "svm_c_trainer_sparse_radial_basis")
.def(py::init())
.def_property("gamma", get_gamma<T>, set_gamma<T>);
m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
......@@ -188,7 +187,8 @@ void bind_svm_c_trainer(py::module& m)
{
typedef svm_c_trainer<histogram_intersection_kernel<sample_type> > T;
setup_trainer2<T>(m, "svm_c_trainer_histogram_intersection");
setup_trainer_eps_c_cache<T>(m, "svm_c_trainer_histogram_intersection")
.def(py::init());
m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
......@@ -197,16 +197,18 @@ void bind_svm_c_trainer(py::module& m)
{
typedef svm_c_trainer<sparse_histogram_intersection_kernel<sparse_vect> > T;
setup_trainer2<T>(m, "svm_c_trainer_sparse_histogram_intersection");
setup_trainer_eps_c_cache<T>(m, "svm_c_trainer_sparse_histogram_intersection")
.def(py::init());
m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
}
// svm_c_linear
{
typedef svm_c_linear_trainer<linear_kernel<sample_type> > T;
setup_trainer<T>(m, "svm_c_trainer_linear")
setup_trainer_eps_c<T>(m, "svm_c_trainer_linear")
.def(py::init())
.def_property("max_iterations", &T::get_max_iterations, &T::set_max_iterations)
.def_property("force_last_weight_to_1", &T::forces_last_weight_to_1, &T::force_last_weight_to_1)
......@@ -224,7 +226,8 @@ void bind_svm_c_trainer(py::module& m)
{
typedef svm_c_linear_trainer<sparse_linear_kernel<sparse_vect> > T;
setup_trainer<T>(m, "svm_c_trainer_sparse_linear")
setup_trainer_eps_c<T>(m, "svm_c_trainer_sparse_linear")
.def(py::init())
.def_property("max_iterations", &T::get_max_iterations, &T::set_max_iterations)
.def_property("force_last_weight_to_1", &T::forces_last_weight_to_1, &T::force_last_weight_to_1)
.def_property("learns_nonnegative_weights", &T::learns_nonnegative_weights, &T::set_learns_nonnegative_weights)
......@@ -238,6 +241,70 @@ void bind_svm_c_trainer(py::module& m)
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
}
// rvm
{
typedef rvm_trainer<radial_basis_kernel<sample_type> > T;
setup_trainer_eps<T>(m, "rvm_trainer_radial_basis")
.def(py::init())
.def_property("gamma", get_gamma<T>, set_gamma<T>);
m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
}
{
typedef rvm_trainer<sparse_radial_basis_kernel<sparse_vect> > T;
setup_trainer_eps<T>(m, "rvm_trainer_sparse_radial_basis")
.def(py::init())
.def_property("gamma", get_gamma<T>, set_gamma<T>);
m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
}
{
typedef rvm_trainer<histogram_intersection_kernel<sample_type> > T;
setup_trainer_eps<T>(m, "rvm_trainer_histogram_intersection")
.def(py::init());
m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
}
{
typedef rvm_trainer<sparse_histogram_intersection_kernel<sparse_vect> > T;
setup_trainer_eps<T>(m, "rvm_trainer_sparse_histogram_intersection")
.def(py::init());
m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
}
// rvm linear
{
typedef rvm_trainer<linear_kernel<sample_type> > T;
setup_trainer_eps<T>(m, "rvm_trainer_linear")
.def(py::init());
m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
}
{
typedef rvm_trainer<sparse_linear_kernel<sparse_vect> > T;
setup_trainer_eps<T>(m, "rvm_trainer_sparse_linear")
.def(py::init());
m.def("cross_validate_trainer", _cross_validate_trainer<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"));
m.def("cross_validate_trainer_threaded", _cross_validate_trainer_t<T>,
py::arg("trainer"),py::arg("x"),py::arg("y"),py::arg("folds"),py::arg("num_threads"));
}
}
from __future__ import division
import pytest
from random import Random
from dlib import (vectors, vector, sparse_vectors, sparse_vector, pair, array,
cross_validate_trainer,
svm_c_trainer_radial_basis,
svm_c_trainer_sparse_radial_basis,
svm_c_trainer_histogram_intersection,
svm_c_trainer_sparse_histogram_intersection,
svm_c_trainer_linear,
svm_c_trainer_sparse_linear,
rvm_trainer_radial_basis,
rvm_trainer_sparse_radial_basis,
rvm_trainer_histogram_intersection,
rvm_trainer_sparse_histogram_intersection,
rvm_trainer_linear,
rvm_trainer_sparse_linear)
@pytest.fixture
def training_data():
r = Random(0)
predictors = vectors()
sparse_predictors = sparse_vectors()
response = array()
for i in range(30):
for c in [-1, 1]:
response.append(c)
values = [r.random() + c * 0.5 for _ in range(3)]
predictors.append(vector(values))
sp = sparse_vector()
for i, v in enumerate(values):
sp.append(pair(i, v))
sparse_predictors.append(sp)
return predictors, sparse_predictors, response
@pytest.mark.parametrize('trainer, class1_accuracy, class2_accuracy', [
(svm_c_trainer_radial_basis, 1.0, 1.0),
(svm_c_trainer_sparse_radial_basis, 1.0, 1.0),
(svm_c_trainer_histogram_intersection, 1.0, 1.0),
(svm_c_trainer_sparse_histogram_intersection, 1.0, 1.0),
(svm_c_trainer_linear, 1.0, 23 / 30),
(svm_c_trainer_sparse_linear, 1.0, 23 / 30),
(rvm_trainer_radial_basis, 1.0, 1.0),
(rvm_trainer_sparse_radial_basis, 1.0, 1.0),
(rvm_trainer_histogram_intersection, 1.0, 1.0),
(rvm_trainer_sparse_histogram_intersection, 1.0, 1.0),
(rvm_trainer_linear, 1.0, 0.6),
(rvm_trainer_sparse_linear, 1.0, 0.6)
])
def test_trainers(training_data, trainer, class1_accuracy, class2_accuracy):
predictors, sparse_predictors, response = training_data
if 'sparse' in trainer.__name__:
predictors = sparse_predictors
cv = cross_validate_trainer(trainer(), predictors, response, folds=10)
assert cv.class1_accuracy == pytest.approx(class1_accuracy)
assert cv.class2_accuracy == pytest.approx(class2_accuracy)
decision_function = trainer().train(predictors, response)
assert decision_function(predictors[2]) < 0
assert decision_function(predictors[3]) > 0
if 'linear' in trainer.__name__:
assert len(decision_function.weights) == 3
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment