Commit 859ccf5e authored by Davis King's avatar Davis King

Added some cross validation wrappers.

parent b8f2b522
#include "testing_results.h"
#include <boost/python.hpp>
#include <boost/shared_ptr.hpp>
#include "serialize_pickle.h"
......@@ -100,18 +101,6 @@ void add_linear_df (
// ----------------------------------------------------------------------------------------
struct binary_test
{
binary_test() : class1_accuracy(0), class2_accuracy(0) {}
binary_test(
const matrix<double,1,2>& m
) : class1_accuracy(m(0)),
class2_accuracy(m(1)) {}
double class1_accuracy;
double class2_accuracy;
};
std::string binary_test__str__(const binary_test& item)
{
std::ostringstream sout;
......@@ -120,18 +109,6 @@ std::string binary_test__str__(const binary_test& item)
}
std::string binary_test__repr__(const binary_test& item) { return "< " + binary_test__str__(item) + " >";}
struct regression_test
{
regression_test() : mean_squared_error(0), R_squared(0) {}
regression_test(
const matrix<double,1,2>& m
) : mean_squared_error(m(0)),
R_squared(m(1)) {}
double mean_squared_error;
double R_squared;
};
std::string regression_test__str__(const regression_test& item)
{
std::ostringstream sout;
......@@ -140,18 +117,6 @@ std::string regression_test__str__(const regression_test& item)
}
std::string regression_test__repr__(const regression_test& item) { return "< " + regression_test__str__(item) + " >";}
struct ranking_test
{
ranking_test() : ranking_accuracy(0), mean_ap(0) {}
ranking_test(
const matrix<double,1,2>& m
) : ranking_accuracy(m(0)),
mean_ap(m(1)) {}
double ranking_accuracy;
double mean_ap;
};
std::string ranking_test__str__(const ranking_test& item)
{
std::ostringstream sout;
......
......@@ -13,12 +13,17 @@ tuple get_training_data()
std::vector<double> labels;
sample_type samp(3);
samp = 1,2,3;
samples.push_back(samp);
labels.push_back(+1);
samp = -1,-2,-3;
samples.push_back(samp);
labels.push_back(-1);
for (int i = 0; i < 10; ++i)
{
samp = 1,2,3;
samples.push_back(samp);
labels.push_back(+1);
samp = -1,-2,-3;
samples.push_back(samp);
labels.push_back(-1);
}
return make_tuple(samples, labels);
}
......
#include "testing_results.h"
#include <boost/python.hpp>
#include <boost/shared_ptr.hpp>
#include <dlib/matrix.h>
#include "serialize_pickle.h"
#include <dlib/svm.h>
#include <dlib/svm_threaded.h>
#include "pyassert.h"
using namespace dlib;
......@@ -118,6 +119,39 @@ double get_gamma_sparse (
return trainer.get_kernel().gamma;
}
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
const binary_test _cross_validate_trainer (
const trainer_type& trainer,
const std::vector<typename trainer_type::sample_type>& x,
const std::vector<double>& y,
const long folds
)
{
pyassert(is_binary_classification_problem(x,y), "Training data does not make a valid training set.");
pyassert(1 < folds && folds <= x.size(), "Invalid number of folds given.");
return cross_validate_trainer(trainer, x, y, folds);
}
template <
typename trainer_type
>
const binary_test _cross_validate_trainer_t (
const trainer_type& trainer,
const std::vector<typename trainer_type::sample_type>& x,
const std::vector<double>& y,
const unsigned long folds,
const unsigned long num_threads
)
{
pyassert(is_binary_classification_problem(x,y), "Training data does not make a valid training set.");
pyassert(1 < folds && folds <= x.size(), "Invalid number of folds given.");
pyassert(1 < num_threads, "The number of threads specified must not be zero.");
return cross_validate_trainer_threaded(trainer, x, y, folds, num_threads);
}
// ----------------------------------------------------------------------------------------
......@@ -125,13 +159,21 @@ void bind_svm_c_trainer()
{
setup_trainer<svm_c_trainer<radial_basis_kernel<sample_type> > >("svm_c_trainer_radial_basis")
.add_property("gamma", get_gamma, set_gamma);
def("cross_validate_trainer", _cross_validate_trainer<svm_c_trainer<radial_basis_kernel<sample_type> > >);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<svm_c_trainer<radial_basis_kernel<sample_type> > >);
setup_trainer<svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> > >("svm_c_trainer_sparse_radial_basis")
.add_property("gamma", get_gamma, set_gamma);
def("cross_validate_trainer", _cross_validate_trainer<svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> > >);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<svm_c_trainer<sparse_radial_basis_kernel<sparse_vect> > >);
setup_trainer<svm_c_trainer<histogram_intersection_kernel<sample_type> > >("svm_c_trainer_histogram_intersection");
def("cross_validate_trainer", _cross_validate_trainer<svm_c_trainer<histogram_intersection_kernel<sample_type> > >);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<svm_c_trainer<histogram_intersection_kernel<sample_type> > >);
setup_trainer<svm_c_trainer<sparse_histogram_intersection_kernel<sparse_vect> > >("svm_c_trainer_sparse_histogram_intersection");
def("cross_validate_trainer", _cross_validate_trainer<svm_c_trainer<sparse_histogram_intersection_kernel<sparse_vect> > >);
def("cross_validate_trainer_threaded", _cross_validate_trainer_t<svm_c_trainer<sparse_histogram_intersection_kernel<sparse_vect> > >);
}
#ifndef DLIB_TESTING_ReSULTS_H__
#define DLIB_TESTING_ReSULTS_H__
#include <dlib/matrix.h>
struct binary_test
{
binary_test() : class1_accuracy(0), class2_accuracy(0) {}
binary_test(
const dlib::matrix<double,1,2>& m
) : class1_accuracy(m(0)),
class2_accuracy(m(1)) {}
double class1_accuracy;
double class2_accuracy;
};
struct regression_test
{
regression_test() : mean_squared_error(0), R_squared(0) {}
regression_test(
const dlib::matrix<double,1,2>& m
) : mean_squared_error(m(0)),
R_squared(m(1)) {}
double mean_squared_error;
double R_squared;
};
struct ranking_test
{
ranking_test() : ranking_accuracy(0), mean_ap(0) {}
ranking_test(
const dlib::matrix<double,1,2>& m
) : ranking_accuracy(m(0)),
mean_ap(m(1)) {}
double ranking_accuracy;
double mean_ap;
};
#endif // DLIB_TESTING_ReSULTS_H__
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment