Commit 65c1551a authored by Davis King

Added some tests for the new krr_trainer. I also simplified the checkerboard dataset used to test the classifiers a little so that the test runs faster.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403760
parent 12a3b2a8
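
For context on what the new tests exercise, here is a minimal, self-contained sketch of krr_trainer used for regression. It relies only on the calls visible in the diff below (set_kernel(), train(), and decision_function); the sample/kernel typedefs and the sinc training data are my own placeholders, not the test's actual definitions.

    #include <dlib/svm.h>
    #include <cmath>
    #include <iostream>
    #include <vector>

    using namespace dlib;

    int main()
    {
        // Assumed 1-D sample type; the test file defines its own typedefs.
        typedef matrix<double,1,1> sample_type;
        typedef radial_basis_kernel<sample_type> kernel_type;

        // Build (x, sinc(x)) training pairs.
        std::vector<sample_type> samples;
        std::vector<double> labels;
        sample_type m;
        for (double x = -10; x <= 4; x += 0.5)
        {
            m(0) = x;
            samples.push_back(m);
            labels.push_back(x == 0 ? 1.0 : std::sin(x)/x);
        }

        // Same call pattern as the test: pick a kernel, then train.
        krr_trainer<kernel_type> trainer;
        trainer.set_kernel(kernel_type(0.1));
        decision_function<kernel_type> df = trainer.train(samples, labels);

        m(0) = 2.5;
        std::cout << "sinc(2.5) ~ " << df(m) << std::endl;
    }
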
@@ -175,6 +175,8 @@ namespace
krls<kernel_type> test(kernel_type(0.1),0.001);
rvm_regression_trainer<kernel_type> rvm_test;
rvm_test.set_kernel(test.get_kernel());
+krr_trainer<kernel_type> krr_test;
+krr_test.set_kernel(test.get_kernel());
rbf_network_trainer<kernel_type> rbf_test;
rbf_test.set_kernel(test.get_kernel());
rbf_test.set_num_centers(13);
@@ -198,6 +200,8 @@ namespace
print_spinner();
decision_function<kernel_type> test3 = rbf_test.train(samples, labels);
print_spinner();
+decision_function<kernel_type> test4 = krr_test.train(samples, labels);
+print_spinner();
// now we output the value of the sinc function for a few test points as well as the
// value predicted by krls object.
@@ -215,6 +219,11 @@ namespace
m(0) = 0.1; dlog << LDEBUG << "rbf: " << sinc(m(0)) << " " << test3(m); DLIB_TEST(abs(sinc(m(0)) - test3(m)) < 0.01);
m(0) = -4; dlog << LDEBUG << "rbf: " << sinc(m(0)) << " " << test3(m); DLIB_TEST(abs(sinc(m(0)) - test3(m)) < 0.01);
m(0) = 5.0; dlog << LDEBUG << "rbf: " << sinc(m(0)) << " " << test3(m); DLIB_TEST(abs(sinc(m(0)) - test3(m)) < 0.01);
+m(0) = 2.5; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01);
+m(0) = 0.1; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01);
+m(0) = -4; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01);
+m(0) = 5.0; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01);
dlog << LINFO << " end test_regression()";
}
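
The checks above compare each prediction against a sinc() helper defined elsewhere in the test file and not shown in this diff. Presumably it is the usual sin(x)/x form used in dlib's regression examples, along these lines:

    #include <cmath>

    // Assumed definition of the sinc() used by the checks above
    // (dlib's regression examples use this form; the test's exact
    // definition is outside the diff).
    double sinc (double x)
    {
        if (x == 0)
            return 1;
        return std::sin(x)/x;
    }
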
@@ -330,7 +339,7 @@ namespace
std::vector<matrix<double,0,1> > x_linearized;
std::vector<scalar_type> y;
-get_checkerboard_problem(x,y, 300, 3);
+get_checkerboard_problem(x,y, 300, 2);
const scalar_type gamma = 1;
typedef radial_basis_kernel<sample_type> kernel_type;
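
get_checkerboard_problem() is also defined elsewhere in the test file, so its exact signature is not visible here; per the commit message, the last argument was lowered from 3 to 2 to make the classification problem a bit simpler so the test runs faster. Purely as an illustration of the kind of data involved (a hypothetical generator, not the test's actual helper), a 2-D checkerboard set could be built like this:

    #include <dlib/matrix.h>
    #include <dlib/rand.h>
    #include <cmath>
    #include <vector>

    // Hypothetical stand-in for the test's get_checkerboard_problem():
    // random 2-D points labeled +1/-1 by the parity of their board cell.
    void make_checkerboard (
        std::vector<dlib::matrix<double,2,1> >& x,
        std::vector<double>& y,
        long num_samples,
        long board_size
    )
    {
        dlib::rand rnd;
        for (long i = 0; i < num_samples; ++i)
        {
            dlib::matrix<double,2,1> p;
            p(0) = rnd.get_random_double()*board_size;
            p(1) = rnd.get_random_double()*board_size;
            x.push_back(p);
            const long cell = (long)std::floor(p(0)) + (long)std::floor(p(1));
            y.push_back(cell % 2 == 0 ? +1 : -1);
        }
    }
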
@@ -342,6 +351,10 @@ namespace
rvm_trainer<kernel_type> rvm_trainer;
rvm_trainer.set_kernel(kernel_type(gamma));
+krr_trainer<kernel_type> krr_trainer;
+krr_trainer.estimate_lambda_for_classification();
+krr_trainer.set_kernel(kernel_type(gamma));
svm_pegasos<kernel_type> pegasos_trainer;
pegasos_trainer.set_kernel(kernel_type(gamma));
pegasos_trainer.set_lambda(0.00001);
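
In this hunk krr_trainer is used as a binary classifier rather than a regressor: the checkerboard labels are +1/-1, and estimate_lambda_for_classification() presumably makes the trainer's automatic regularization estimate target classification accuracy. A minimal sketch of that usage, assuming the typedefs and data containers from the surrounding test:

    // Sketch: krr_trainer as a +1/-1 classifier on checkerboard data.
    // Assumes the sample_type/kernel_type typedefs used by this test and
    // x, y containers filled with checkerboard samples and labels.
    krr_trainer<kernel_type> trainer;
    trainer.estimate_lambda_for_classification();   // same call as in the hunk above
    trainer.set_kernel(kernel_type(gamma));

    decision_function<kernel_type> df = trainer.train(x, y);

    // krr produces a real-valued output; the sign gives the predicted class.
    const double predicted_label = (df(x[0]) >= 0) ? +1 : -1;
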
@@ -367,6 +380,8 @@ namespace
print_spinner();
matrix<scalar_type> rvm_cv = cross_validate_trainer_threaded(rvm_trainer, x,y, 4, 2);
print_spinner();
+matrix<scalar_type> krr_cv = cross_validate_trainer_threaded(krr_trainer, x,y, 4, 2);
+print_spinner();
matrix<scalar_type> svm_cv = cross_validate_trainer(trainer, x,y, 4);
print_spinner();
matrix<scalar_type> rbf_cv = cross_validate_trainer_threaded(rbf_trainer, x,y, 10, 2);
@@ -384,6 +399,7 @@ namespace
print_spinner();
dlog << LDEBUG << "rvm cv: " << rvm_cv;
dlog << LDEBUG << "krr cv: " << krr_cv;
dlog << LDEBUG << "svm cv: " << svm_cv;
dlog << LDEBUG << "rbf cv: " << rbf_cv;
dlog << LDEBUG << "lin cv: " << lin_cv;
@@ -397,6 +413,7 @@ namespace
sum(abs(peg_cv - peg_c_cv)) << " \n" << peg_cv << peg_c_cv );
DLIB_TEST_MSG(mean(rvm_cv) > 0.9, rvm_cv);
+DLIB_TEST_MSG(mean(krr_cv) > 0.9, krr_cv);
DLIB_TEST_MSG(mean(svm_cv) > 0.9, svm_cv);
DLIB_TEST_MSG(mean(rbf_cv) > 0.9, rbf_cv);
DLIB_TEST_MSG(mean(lin_cv) > 0.9, lin_cv);
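
As for what the assertions above measure: as I read dlib's cross-validation helpers, cross_validate_trainer() and cross_validate_trainer_threaded() return a row matrix holding the fraction of +1 and the fraction of -1 examples classified correctly, so mean(cv) > 0.9 asks the two per-class accuracies to average above 90%. A short sketch of that interpretation, reusing the names from the hunks above:

    // Reuses krr_trainer, x, y, and scalar_type from the test above.
    const matrix<scalar_type> cv = cross_validate_trainer(krr_trainer, x, y, 4);
    // cv(0): fraction of +1 labeled samples classified correctly,
    // cv(1): fraction of -1 labeled samples classified correctly
    // (my reading of dlib's cross-validation docs; not shown in this diff).
    const double balanced_accuracy = (cv(0) + cv(1))/2;   // same value mean(cv) checks
    DLIB_TEST_MSG(balanced_accuracy > 0.9, cv);
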