Commit 9ac9ad9e authored by Davis King's avatar Davis King

Fixed a bug pointed out by Justin Solomon which could cause the svr_trainer and

svm_c_trainer to produce incorrect results in certain unusual cases.  Also added
unit tests to make sure the bug stays fixed.
parent c5a68e71
...@@ -311,13 +311,13 @@ namespace dlib ...@@ -311,13 +311,13 @@ namespace dlib
{ {
if(alpha(i) == Cneg) if(alpha(i) == Cneg)
{ {
if (-df(i) > upper_bound) if (-df(i) < lower_bound)
upper_bound = -df(i); lower_bound = -df(i);
} }
else if(alpha(i) == 0) else if(alpha(i) == 0)
{ {
if (-df(i) < lower_bound) if (-df(i) > upper_bound)
lower_bound = -df(i); upper_bound = -df(i);
} }
else else
{ {
......
...@@ -316,8 +316,8 @@ namespace dlib ...@@ -316,8 +316,8 @@ namespace dlib
long num_free = 0; long num_free = 0;
scalar_type sum_free = 0; scalar_type sum_free = 0;
scalar_type upper_bound; scalar_type upper_bound = -numeric_limits<scalar_type>::infinity();
scalar_type lower_bound; scalar_type lower_bound = numeric_limits<scalar_type>::infinity();
find_min_and_max(df, upper_bound, lower_bound); find_min_and_max(df, upper_bound, lower_bound);
...@@ -345,13 +345,13 @@ namespace dlib ...@@ -345,13 +345,13 @@ namespace dlib
{ {
if(alpha(i) == C) if(alpha(i) == C)
{ {
if (-df(i) > upper_bound) if (-df(i) < lower_bound)
upper_bound = -df(i); lower_bound = -df(i);
} }
else if(alpha(i) == 0) else if(alpha(i) == 0)
{ {
if (-df(i) < lower_bound) if (-df(i) > upper_bound)
lower_bound = -df(i); upper_bound = -df(i);
} }
else else
{ {
......
...@@ -190,7 +190,9 @@ namespace ...@@ -190,7 +190,9 @@ namespace
print_spinner(); print_spinner();
std::vector<sample_type> samples; std::vector<sample_type> samples;
std::vector<sample_type> samples2;
std::vector<double> labels; std::vector<double> labels;
std::vector<double> labels2;
// now we train our object on a few samples of the sinc function. // now we train our object on a few samples of the sinc function.
sample_type m; sample_type m;
for (double x = -10; x <= 5; x += 0.6) for (double x = -10; x <= 5; x += 0.6)
...@@ -199,7 +201,9 @@ namespace ...@@ -199,7 +201,9 @@ namespace
test.train(m, sinc(x)); test.train(m, sinc(x));
samples.push_back(m); samples.push_back(m);
samples2.push_back(m);
labels.push_back(sinc(x)); labels.push_back(sinc(x));
labels2.push_back(2);
} }
print_spinner(); print_spinner();
...@@ -250,6 +254,17 @@ namespace ...@@ -250,6 +254,17 @@ namespace
DLIB_TEST(cv(0) < 1e-4); DLIB_TEST(cv(0) < 1e-4);
DLIB_TEST(cv(1) > 0.99); DLIB_TEST(cv(1) > 0.99);
randomize_samples(samples2, labels2);
dlog << LINFO << "KRR MSE and R-squared: "<< cross_validate_regression_trainer(krr_test, samples2, labels2, 6);
dlog << LINFO << "SVR MSE and R-squared: "<< cross_validate_regression_trainer(svr_test, samples2, labels2, 6);
cv = cross_validate_regression_trainer(krr_test, samples2, labels2, 6);
DLIB_TEST(cv(0) < 1e-4);
cv = cross_validate_regression_trainer(svr_test, samples2, labels2, 6);
DLIB_TEST(cv(0) < 1e-4);
dlog << LINFO << " end test_regression()"; dlog << LINFO << " end test_regression()";
} }
...@@ -569,6 +584,55 @@ namespace ...@@ -569,6 +584,55 @@ namespace
} }
} }
// ----------------------------------------------------------------------------------------
void test_svm_trainer2()
{
typedef matrix<double, 2, 1> sample_type;
typedef linear_kernel<sample_type> kernel_type;
std::vector<sample_type> samples;
std::vector<double> labels;
sample_type samp;
samp(0) = 1;
samp(1) = 1;
samples.push_back(samp);
labels.push_back(+1);
samp(0) = 1;
samp(1) = 2;
samples.push_back(samp);
labels.push_back(-1);
svm_c_trainer<kernel_type> trainer;
decision_function<kernel_type> df = trainer.train(samples, labels);
samp(0) = 1;
samp(1) = 1;
dlog << LINFO << "test +1 : "<< df(samp);
DLIB_TEST(df(samp) > 0);
samp(0) = 1;
samp(1) = 2;
dlog << LINFO << "test -1 : "<< df(samp);
DLIB_TEST(df(samp) < 0);
svm_nu_trainer<kernel_type> trainer2;
df = trainer2.train(samples, labels);
samp(0) = 1;
samp(1) = 1;
dlog << LINFO << "test +1 : "<< df(samp);
DLIB_TEST(df(samp) > 0);
samp(0) = 1;
samp(1) = 2;
dlog << LINFO << "test -1 : "<< df(samp);
DLIB_TEST(df(samp) < 0);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
class svm_tester : public tester class svm_tester : public tester
...@@ -588,6 +652,7 @@ namespace ...@@ -588,6 +652,7 @@ namespace
test_clutering(); test_clutering();
test_regression(); test_regression();
test_anomaly_detection(); test_anomaly_detection();
test_svm_trainer2();
} }
} a; } a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment