Commit 5222b168 authored by Davis King's avatar Davis King

Added unit tests for the svr_trainer and svm_one_class_trainer.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404027
parent 525f3a8b
...@@ -175,8 +175,15 @@ namespace ...@@ -175,8 +175,15 @@ namespace
krls<kernel_type> test(kernel_type(0.1),0.001); krls<kernel_type> test(kernel_type(0.1),0.001);
rvm_regression_trainer<kernel_type> rvm_test; rvm_regression_trainer<kernel_type> rvm_test;
rvm_test.set_kernel(test.get_kernel()); rvm_test.set_kernel(test.get_kernel());
krr_trainer<kernel_type> krr_test; krr_trainer<kernel_type> krr_test;
krr_test.set_kernel(test.get_kernel()); krr_test.set_kernel(test.get_kernel());
svr_trainer<kernel_type> svr_test;
svr_test.set_kernel(test.get_kernel());
svr_test.set_epsilon_insensitivity(0.0001);
svr_test.set_c(10);
rbf_network_trainer<kernel_type> rbf_test; rbf_network_trainer<kernel_type> rbf_test;
rbf_test.set_kernel(test.get_kernel()); rbf_test.set_kernel(test.get_kernel());
rbf_test.set_num_centers(13); rbf_test.set_num_centers(13);
...@@ -202,6 +209,8 @@ namespace ...@@ -202,6 +209,8 @@ namespace
print_spinner(); print_spinner();
decision_function<kernel_type> test4 = krr_test.train(samples, labels); decision_function<kernel_type> test4 = krr_test.train(samples, labels);
print_spinner(); print_spinner();
decision_function<kernel_type> test5 = svr_test.train(samples, labels);
print_spinner();
// now we output the value of the sinc function for a few test points as well as the // now we output the value of the sinc function for a few test points as well as the
// value predicted by krls object. // value predicted by krls object.
...@@ -224,6 +233,11 @@ namespace ...@@ -224,6 +233,11 @@ namespace
m(0) = 0.1; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01); m(0) = 0.1; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01);
m(0) = -4; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01); m(0) = -4; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01);
m(0) = 5.0; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01); m(0) = 5.0; dlog << LDEBUG << "krr: " << sinc(m(0)) << " " << test4(m); DLIB_TEST(abs(sinc(m(0)) - test4(m)) < 0.01);
m(0) = 2.5; dlog << LDEBUG << "svr: " << sinc(m(0)) << " " << test5(m); DLIB_TEST(abs(sinc(m(0)) - test5(m)) < 0.01);
m(0) = 0.1; dlog << LDEBUG << "svr: " << sinc(m(0)) << " " << test5(m); DLIB_TEST(abs(sinc(m(0)) - test5(m)) < 0.01);
m(0) = -4; dlog << LDEBUG << "svr: " << sinc(m(0)) << " " << test5(m); DLIB_TEST(abs(sinc(m(0)) - test5(m)) < 0.01);
m(0) = 5.0; dlog << LDEBUG << "svr: " << sinc(m(0)) << " " << test5(m); DLIB_TEST(abs(sinc(m(0)) - test5(m)) < 0.01);
dlog << LINFO << " end test_regression()"; dlog << LINFO << " end test_regression()";
} }
...@@ -249,6 +263,13 @@ namespace ...@@ -249,6 +263,13 @@ namespace
// Here we have set it to 0.01. // Here we have set it to 0.01.
kcentroid<kernel_type> test(kernel_type(0.1),0.01); kcentroid<kernel_type> test(kernel_type(0.1),0.01);
svm_one_class_trainer<kernel_type> one_class_trainer;
one_class_trainer.set_nu(0.4);
one_class_trainer.set_kernel(kernel_type(0.2));
std::vector<sample_type> samples;
// now we train our object on a few samples of the sinc function. // now we train our object on a few samples of the sinc function.
sample_type m; sample_type m;
for (double x = -15; x <= 8; x += 1) for (double x = -15; x <= 8; x += 1)
...@@ -256,8 +277,11 @@ namespace ...@@ -256,8 +277,11 @@ namespace
m(0) = x; m(0) = x;
m(1) = sinc(x); m(1) = sinc(x);
test.train(m); test.train(m);
samples.push_back(m);
} }
decision_function<kernel_type> df = one_class_trainer.train(samples);
running_stats<double> rs; running_stats<double> rs;
// Now lets output the distance from the centroid to some points that are from the sinc function. // Now lets output the distance from the centroid to some points that are from the sinc function.
...@@ -281,6 +305,15 @@ namespace ...@@ -281,6 +305,15 @@ namespace
m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m))); m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m)));
m(0) = -0.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m))); m(0) = -0.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(rs.scale(test(m)) < 2, rs.scale(test(m)));
const double thresh = 0.01;
m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m));
m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m));
m(0) = -0; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m));
m(0) = -0.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m));
m(0) = -4.1; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m));
m(0) = -1.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m));
m(0) = -0.5; m(1) = sinc(m(0)); DLIB_TEST_MSG(df(m)+thresh > 0, df(m));
dlog << LDEBUG; dlog << LDEBUG;
// Lets output the distance from the centroid to some points that are NOT from the sinc function. // Lets output the distance from the centroid to some points that are NOT from the sinc function.
// These numbers should all be significantly bigger than previous set of numbers. We will also // These numbers should all be significantly bigger than previous set of numbers. We will also
...@@ -291,30 +324,37 @@ namespace ...@@ -291,30 +324,37 @@ namespace
m(0) = -1.5; m(1) = sinc(m(0))+4; m(0) = -1.5; m(1) = sinc(m(0))+4;
dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc."; dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc.";
DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m))); DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m)));
DLIB_TEST_MSG(df(m) + thresh < 0, df(m));
m(0) = -1.5; m(1) = sinc(m(0))+3; m(0) = -1.5; m(1) = sinc(m(0))+3;
dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc."; dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc.";
DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m))); DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m)));
DLIB_TEST_MSG(df(m) + thresh < 0, df(m));
m(0) = -0; m(1) = -sinc(m(0)); m(0) = -0; m(1) = -sinc(m(0));
dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc."; dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc.";
DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m))); DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m)));
DLIB_TEST_MSG(df(m) + thresh < 0, df(m));
m(0) = -0.5; m(1) = -sinc(m(0)); m(0) = -0.5; m(1) = -sinc(m(0));
dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc."; dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc.";
DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m))); DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m)));
DLIB_TEST_MSG(df(m) + thresh < 0, df(m));
m(0) = -4.1; m(1) = sinc(m(0))+2; m(0) = -4.1; m(1) = sinc(m(0))+2;
dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc."; dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc.";
DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m))); DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m)));
DLIB_TEST_MSG(df(m) + thresh < 0, df(m));
m(0) = -1.5; m(1) = sinc(m(0))+0.9; m(0) = -1.5; m(1) = sinc(m(0))+0.9;
dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc."; dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc.";
DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m))); DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m)));
DLIB_TEST_MSG(df(m) + thresh < 0, df(m));
m(0) = -0.5; m(1) = sinc(m(0))+1; m(0) = -0.5; m(1) = sinc(m(0))+1;
dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc."; dlog << LDEBUG << " " << test(m) << " is " << rs.scale(test(m)) << " standard deviations from sinc.";
DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m))); DLIB_TEST_MSG(rs.scale(test(m)) > 6, rs.scale(test(m)));
DLIB_TEST_MSG(df(m) + thresh < 0, df(m));
dlog << LINFO << " end test_anomaly_detection()"; dlog << LINFO << " end test_anomaly_detection()";
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment