Commit b8f2b522 authored by Davis King's avatar Davis King

Added decision function testing wrappers.

parent bca5cddf
......@@ -98,6 +98,97 @@ void add_linear_df (
.def_pickle(serialize_pickle<df_type>());
}
// ----------------------------------------------------------------------------------------
struct binary_test
{
binary_test() : class1_accuracy(0), class2_accuracy(0) {}
binary_test(
const matrix<double,1,2>& m
) : class1_accuracy(m(0)),
class2_accuracy(m(1)) {}
double class1_accuracy;
double class2_accuracy;
};
std::string binary_test__str__(const binary_test& item)
{
std::ostringstream sout;
sout << "class1_accuracy: "<< item.class1_accuracy << " class2_accuracy: "<< item.class2_accuracy;
return sout.str();
}
std::string binary_test__repr__(const binary_test& item) { return "< " + binary_test__str__(item) + " >";}
struct regression_test
{
regression_test() : mean_squared_error(0), R_squared(0) {}
regression_test(
const matrix<double,1,2>& m
) : mean_squared_error(m(0)),
R_squared(m(1)) {}
double mean_squared_error;
double R_squared;
};
std::string regression_test__str__(const regression_test& item)
{
std::ostringstream sout;
sout << "mean_squared_error: "<< item.mean_squared_error << " R_squared: "<< item.R_squared;
return sout.str();
}
std::string regression_test__repr__(const regression_test& item) { return "< " + regression_test__str__(item) + " >";}
struct ranking_test
{
ranking_test() : ranking_accuracy(0), mean_ap(0) {}
ranking_test(
const matrix<double,1,2>& m
) : ranking_accuracy(m(0)),
mean_ap(m(1)) {}
double ranking_accuracy;
double mean_ap;
};
std::string ranking_test__str__(const ranking_test& item)
{
std::ostringstream sout;
sout << "ranking_accuracy: "<< item.ranking_accuracy << " mean_ap: "<< item.mean_ap;
return sout.str();
}
std::string ranking_test__repr__(const ranking_test& item) { return "< " + ranking_test__str__(item) + " >";}
// ----------------------------------------------------------------------------------------
template <typename K>
binary_test _test_binary_decision_function (
const decision_function<K>& dec_funct,
const std::vector<typename K::sample_type>& x_test,
const std::vector<double>& y_test
) { return binary_test(test_binary_decision_function(dec_funct, x_test, y_test)); }
template <typename K>
regression_test _test_regression_function (
const decision_function<K>& reg_funct,
const std::vector<typename K::sample_type>& x_test,
const std::vector<double>& y_test
) { return regression_test(test_regression_function(reg_funct, x_test, y_test)); }
template < typename K >
ranking_test _test_ranking_function1 (
const decision_function<K>& funct,
const std::vector<ranking_pair<typename K::sample_type> >& samples
) { return ranking_test(test_ranking_function(funct, samples)); }
template < typename K >
ranking_test _test_ranking_function2 (
const decision_function<K>& funct,
const ranking_pair<typename K::sample_type>& sample
) { return ranking_test(test_ranking_function(funct, sample)); }
void bind_decision_functions()
{
add_linear_df<linear_kernel<sample_type> >("_decision_function_linear");
......@@ -114,6 +205,53 @@ void bind_decision_functions()
add_df<sigmoid_kernel<sample_type> >("_decision_function_sigmoid");
add_df<sparse_sigmoid_kernel<sparse_vect> >("_decision_function_sparse_sigmoid");
def("test_binary_decision_function", _test_binary_decision_function<linear_kernel<sample_type> >);
def("test_binary_decision_function", _test_binary_decision_function<sparse_linear_kernel<sparse_vect> >);
def("test_binary_decision_function", _test_binary_decision_function<radial_basis_kernel<sample_type> >);
def("test_binary_decision_function", _test_binary_decision_function<sparse_radial_basis_kernel<sparse_vect> >);
def("test_binary_decision_function", _test_binary_decision_function<polynomial_kernel<sample_type> >);
def("test_binary_decision_function", _test_binary_decision_function<sparse_polynomial_kernel<sparse_vect> >);
def("test_binary_decision_function", _test_binary_decision_function<histogram_intersection_kernel<sample_type> >);
def("test_binary_decision_function", _test_binary_decision_function<sparse_histogram_intersection_kernel<sparse_vect> >);
def("test_binary_decision_function", _test_binary_decision_function<sigmoid_kernel<sample_type> >);
def("test_binary_decision_function", _test_binary_decision_function<sparse_sigmoid_kernel<sparse_vect> >);
def("test_regression_function", _test_regression_function<linear_kernel<sample_type> >);
def("test_regression_function", _test_regression_function<sparse_linear_kernel<sparse_vect> >);
def("test_regression_function", _test_regression_function<radial_basis_kernel<sample_type> >);
def("test_regression_function", _test_regression_function<sparse_radial_basis_kernel<sparse_vect> >);
def("test_regression_function", _test_regression_function<histogram_intersection_kernel<sample_type> >);
def("test_regression_function", _test_regression_function<sparse_histogram_intersection_kernel<sparse_vect> >);
def("test_regression_function", _test_regression_function<sigmoid_kernel<sample_type> >);
def("test_regression_function", _test_regression_function<sparse_sigmoid_kernel<sparse_vect> >);
def("test_regression_function", _test_regression_function<polynomial_kernel<sample_type> >);
def("test_regression_function", _test_regression_function<sparse_polynomial_kernel<sparse_vect> >);
def("test_ranking_function", _test_ranking_function1<linear_kernel<sample_type> >);
def("test_ranking_function", _test_ranking_function1<sparse_linear_kernel<sparse_vect> >);
def("test_ranking_function", _test_ranking_function2<linear_kernel<sample_type> >);
def("test_ranking_function", _test_ranking_function2<sparse_linear_kernel<sparse_vect> >);
class_<binary_test>("_binary_test")
.add_property("class1_accuracy", &binary_test::class1_accuracy)
.def("__str__", binary_test__str__)
.def("__repr__", binary_test__repr__)
.add_property("class2_accuracy", &binary_test::class2_accuracy);
class_<ranking_test>("_ranking_test")
.add_property("ranking_accuracy", &ranking_test::ranking_accuracy)
.def("__str__", ranking_test__str__)
.def("__repr__", ranking_test__repr__)
.add_property("mean_ap", &ranking_test::mean_ap);
class_<regression_test>("_regression_test")
.add_property("mean_squared_error", &regression_test::mean_squared_error)
.def("__str__", regression_test__str__)
.def("__repr__", regression_test__repr__)
.add_property("R_squared", &regression_test::R_squared);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment