Commit 9ac9ad9e authored by Davis King's avatar Davis King

Fixed a bug pointed out by Justin Solomon which could cause the svr_trainer and

svm_c_trainer to produce incorrect results in certain unusual cases.  Also added
unit tests to make sure the bug stays fixed.
parent c5a68e71
......@@ -311,13 +311,13 @@ namespace dlib
{
if(alpha(i) == Cneg)
{
if (-df(i) > upper_bound)
upper_bound = -df(i);
if (-df(i) < lower_bound)
lower_bound = -df(i);
}
else if(alpha(i) == 0)
{
if (-df(i) < lower_bound)
lower_bound = -df(i);
if (-df(i) > upper_bound)
upper_bound = -df(i);
}
else
{
......
......@@ -316,8 +316,8 @@ namespace dlib
long num_free = 0;
scalar_type sum_free = 0;
scalar_type upper_bound;
scalar_type lower_bound;
scalar_type upper_bound = -numeric_limits<scalar_type>::infinity();
scalar_type lower_bound = numeric_limits<scalar_type>::infinity();
find_min_and_max(df, upper_bound, lower_bound);
......@@ -345,13 +345,13 @@ namespace dlib
{
if(alpha(i) == C)
{
if (-df(i) > upper_bound)
upper_bound = -df(i);
if (-df(i) < lower_bound)
lower_bound = -df(i);
}
else if(alpha(i) == 0)
{
if (-df(i) < lower_bound)
lower_bound = -df(i);
if (-df(i) > upper_bound)
upper_bound = -df(i);
}
else
{
......
......@@ -190,7 +190,9 @@ namespace
print_spinner();
std::vector<sample_type> samples;
std::vector<sample_type> samples2;
std::vector<double> labels;
std::vector<double> labels2;
// now we train our object on a few samples of the sinc function.
sample_type m;
for (double x = -10; x <= 5; x += 0.6)
......@@ -199,7 +201,9 @@ namespace
test.train(m, sinc(x));
samples.push_back(m);
samples2.push_back(m);
labels.push_back(sinc(x));
labels2.push_back(2);
}
print_spinner();
......@@ -250,6 +254,17 @@ namespace
DLIB_TEST(cv(0) < 1e-4);
DLIB_TEST(cv(1) > 0.99);
randomize_samples(samples2, labels2);
dlog << LINFO << "KRR MSE and R-squared: "<< cross_validate_regression_trainer(krr_test, samples2, labels2, 6);
dlog << LINFO << "SVR MSE and R-squared: "<< cross_validate_regression_trainer(svr_test, samples2, labels2, 6);
cv = cross_validate_regression_trainer(krr_test, samples2, labels2, 6);
DLIB_TEST(cv(0) < 1e-4);
cv = cross_validate_regression_trainer(svr_test, samples2, labels2, 6);
DLIB_TEST(cv(0) < 1e-4);
dlog << LINFO << " end test_regression()";
}
......@@ -569,6 +584,55 @@ namespace
}
}
// ----------------------------------------------------------------------------------------
void test_svm_trainer2()
{
typedef matrix<double, 2, 1> sample_type;
typedef linear_kernel<sample_type> kernel_type;
std::vector<sample_type> samples;
std::vector<double> labels;
sample_type samp;
samp(0) = 1;
samp(1) = 1;
samples.push_back(samp);
labels.push_back(+1);
samp(0) = 1;
samp(1) = 2;
samples.push_back(samp);
labels.push_back(-1);
svm_c_trainer<kernel_type> trainer;
decision_function<kernel_type> df = trainer.train(samples, labels);
samp(0) = 1;
samp(1) = 1;
dlog << LINFO << "test +1 : "<< df(samp);
DLIB_TEST(df(samp) > 0);
samp(0) = 1;
samp(1) = 2;
dlog << LINFO << "test -1 : "<< df(samp);
DLIB_TEST(df(samp) < 0);
svm_nu_trainer<kernel_type> trainer2;
df = trainer2.train(samples, labels);
samp(0) = 1;
samp(1) = 1;
dlog << LINFO << "test +1 : "<< df(samp);
DLIB_TEST(df(samp) > 0);
samp(0) = 1;
samp(1) = 2;
dlog << LINFO << "test -1 : "<< df(samp);
DLIB_TEST(df(samp) < 0);
}
// ----------------------------------------------------------------------------------------
class svm_tester : public tester
......@@ -588,6 +652,7 @@ namespace
test_clutering();
test_regression();
test_anomaly_detection();
test_svm_trainer2();
}
} a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment