Commit 859ccf5e authored by Davis King's avatar Davis King

Added some cross validation wrappers.

parent b8f2b522
#include "testing_results.h"
#include <boost/python.hpp> #include <boost/python.hpp>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include "serialize_pickle.h" #include "serialize_pickle.h"
...@@ -100,18 +101,6 @@ void add_linear_df ( ...@@ -100,18 +101,6 @@ void add_linear_df (
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
struct binary_test
{
binary_test() : class1_accuracy(0), class2_accuracy(0) {}
binary_test(
const matrix<double,1,2>& m
) : class1_accuracy(m(0)),
class2_accuracy(m(1)) {}
double class1_accuracy;
double class2_accuracy;
};
std::string binary_test__str__(const binary_test& item) std::string binary_test__str__(const binary_test& item)
{ {
std::ostringstream sout; std::ostringstream sout;
...@@ -120,18 +109,6 @@ std::string binary_test__str__(const binary_test& item) ...@@ -120,18 +109,6 @@ std::string binary_test__str__(const binary_test& item)
} }
std::string binary_test__repr__(const binary_test& item) { return "< " + binary_test__str__(item) + " >";} std::string binary_test__repr__(const binary_test& item) { return "< " + binary_test__str__(item) + " >";}
struct regression_test
{
regression_test() : mean_squared_error(0), R_squared(0) {}
regression_test(
const matrix<double,1,2>& m
) : mean_squared_error(m(0)),
R_squared(m(1)) {}
double mean_squared_error;
double R_squared;
};
std::string regression_test__str__(const regression_test& item) std::string regression_test__str__(const regression_test& item)
{ {
std::ostringstream sout; std::ostringstream sout;
...@@ -140,18 +117,6 @@ std::string regression_test__str__(const regression_test& item) ...@@ -140,18 +117,6 @@ std::string regression_test__str__(const regression_test& item)
} }
std::string regression_test__repr__(const regression_test& item) { return "< " + regression_test__str__(item) + " >";} std::string regression_test__repr__(const regression_test& item) { return "< " + regression_test__str__(item) + " >";}
struct ranking_test
{
ranking_test() : ranking_accuracy(0), mean_ap(0) {}
ranking_test(
const matrix<double,1,2>& m
) : ranking_accuracy(m(0)),
mean_ap(m(1)) {}
double ranking_accuracy;
double mean_ap;
};
std::string ranking_test__str__(const ranking_test& item) std::string ranking_test__str__(const ranking_test& item)
{ {
std::ostringstream sout; std::ostringstream sout;
......
...@@ -13,12 +13,17 @@ tuple get_training_data() ...@@ -13,12 +13,17 @@ tuple get_training_data()
std::vector<double> labels; std::vector<double> labels;
sample_type samp(3); sample_type samp(3);
samp = 1,2,3;
samples.push_back(samp); for (int i = 0; i < 10; ++i)
labels.push_back(+1); {
samp = -1,-2,-3; samp = 1,2,3;
samples.push_back(samp); samples.push_back(samp);
labels.push_back(-1); labels.push_back(+1);
samp = -1,-2,-3;
samples.push_back(samp);
labels.push_back(-1);
}
return make_tuple(samples, labels); return make_tuple(samples, labels);
} }
......
#include "testing_results.h"
#include <boost/python.hpp> #include <boost/python.hpp>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <dlib/matrix.h> #include <dlib/matrix.h>
#include "serialize_pickle.h" #include "serialize_pickle.h"
#include <dlib/svm.h> #include <dlib/svm_threaded.h>
#include "pyassert.h" #include "pyassert.h"
using namespace dlib; using namespace dlib;
...@@ -118,6 +119,39 @@ double get_gamma_sparse ( ...@@ -118,6 +119,39 @@ double get_gamma_sparse (
return trainer.get_kernel().gamma; return trainer.get_kernel().gamma;
} }
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
const binary_test _cross_validate_trainer (
const trainer_type& trainer,
const std::vector<typename trainer_type::sample_type>& x,
const std::vector<double>& y,
const long folds
)
{
pyassert(is_binary_classification_problem(x,y), "Training data does not make a valid training set.");
pyassert(1 < folds && folds <= x.size(), "Invalid number of folds given.");
return cross_validate_trainer(trainer, x, y, folds);
}
template <
typename trainer_type
>
const binary_test _cross_validate_trainer_t (
const trainer_type& trainer,
const std::vector<typename trainer_type::sample_type>& x,
const std::vector<double>& y,
const unsigned long folds,
const unsigned long num_threads
)
{
pyassert(is_binary_classification_problem(x,y), "Training data does not make a valid training set.");
pyassert(1 < folds && folds <= x.size(), "Invalid number of folds given.");
pyassert(1 < num_threads, "The number of threads specified must not be zero.");
return cross_validate_trainer_threaded(trainer, x, y, folds, num_threads);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -125,13 +159,21 @@ void bind_svm_c_trainer() ...@@ -125,13 +159,21 @@ void bind_svm_c_trainer()
{ {
setup_trainer<svm_c_trainer<radial_basis_kernel<sample_type> > >("svm_c_trainer_radial_basis") setup_trainer<svm_c_trainer<radial_basis_kernel<sample_type> > >("svm_c_trainer_radial_basis")
.add_property("gamma", get_gamma, set_gamma); .add_property("gamma", get_gamma, set_gamma);
def("cross_validate_trainer", _cross_validate_trainer<svm_c_trainer<radial_basis_kernel<sample_type> > >);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<svm_c_trainer<radial_basis_kernel<sample_type> > >);
setup_trainer<svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> > >("svm_c_trainer_sparse_radial_basis") setup_trainer<svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> > >("svm_c_trainer_sparse_radial_basis")
.add_property("gamma", get_gamma, set_gamma); .add_property("gamma", get_gamma, set_gamma);
def("cross_validate_trainer", _cross_validate_trainer<svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> > >);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> > >);
setup_trainer<svm_c_trainer<histogram_intersection_kernel<sample_type> > >("svm_c_trainer_histogram_intersection"); setup_trainer<svm_c_trainer<histogram_intersection_kernel<sample_type> > >("svm_c_trainer_histogram_intersection");
def("cross_validate_trainer", _cross_validate_trainer<svm_c_trainer<histogram_intersection_kernel<sample_type> > >);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<svm_c_trainer<histogram_intersection_kernel<sample_type> > >);
setup_trainer<svm_c_trainer<sparse_histogram_intersection_kernel<sparse_vect> > >("svm_c_trainer_sparse_histogram_intersection"); setup_trainer<svm_c_trainer<sparse_histogram_intersection_kernel<sparse_vect> > >("svm_c_trainer_sparse_histogram_intersection");
def("cross_validate_trainer", _cross_validate_trainer<svm_c_trainer<sparse_histogram_intersection_kernel<sparse_vect> > >);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<svm_c_trainer<sparse_histogram_intersection_kernel<sparse_vect> > >);
} }
#ifndef DLIB_TESTING_ReSULTS_H__
#define DLIB_TESTING_ReSULTS_H__
#include <dlib/matrix.h>
struct binary_test
{
binary_test() : class1_accuracy(0), class2_accuracy(0) {}
binary_test(
const dlib::matrix<double,1,2>& m
) : class1_accuracy(m(0)),
class2_accuracy(m(1)) {}
double class1_accuracy;
double class2_accuracy;
};
struct regression_test
{
regression_test() : mean_squared_error(0), R_squared(0) {}
regression_test(
const dlib::matrix<double,1,2>& m
) : mean_squared_error(m(0)),
R_squared(m(1)) {}
double mean_squared_error;
double R_squared;
};
struct ranking_test
{
ranking_test() : ranking_accuracy(0), mean_ap(0) {}
ranking_test(
const dlib::matrix<double,1,2>& m
) : ranking_accuracy(m(0)),
mean_ap(m(1)) {}
double ranking_accuracy;
double mean_ap;
};
#endif // DLIB_TESTING_ReSULTS_H__
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment