Commit ebaf7b70 authored by Davis King's avatar Davis King

You shouldn't use a nu of >= to maximum_nu() with the svm_nu_trainer object.

However, this was incorrectly documented as > rather than >= and the code
to detect when a user gave an invalid nu was similarly incorrect.  This
has been fixed.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403393
parent 31c304e1
...@@ -1237,6 +1237,7 @@ namespace dlib ...@@ -1237,6 +1237,7 @@ namespace dlib
long num = (long)std::floor(temp); long num = (long)std::floor(temp);
long num_total = (long)std::ceil(temp); long num_total = (long)std::ceil(temp);
bool has_slack = false;
int count = 0; int count = 0;
for (int i = 0; i < alpha.nr(); ++i) for (int i = 0; i < alpha.nr(); ++i)
{ {
...@@ -1249,6 +1250,7 @@ namespace dlib ...@@ -1249,6 +1250,7 @@ namespace dlib
} }
else else
{ {
has_slack = true;
if (temp > num) if (temp > num)
{ {
++count; ++count;
...@@ -1259,13 +1261,14 @@ namespace dlib ...@@ -1259,13 +1261,14 @@ namespace dlib
} }
} }
if (count != num_total) if (count != num_total || has_slack == false)
{ {
std::ostringstream sout; std::ostringstream sout;
sout << "invalid nu of " << nu << ". Must be between 0 and " << (scalar_type)count/y.nr(); sout << "Invalid nu of " << nu << ". It is required that: 0 < nu < " << 2*(scalar_type)count/y.nr();
throw invalid_svm_nu_error(sout.str(),nu); throw invalid_svm_nu_error(sout.str(),nu);
} }
has_slack = false;
count = 0; count = 0;
for (int i = 0; i < alpha.nr(); ++i) for (int i = 0; i < alpha.nr(); ++i)
{ {
...@@ -1278,6 +1281,7 @@ namespace dlib ...@@ -1278,6 +1281,7 @@ namespace dlib
} }
else else
{ {
has_slack = true;
if (temp > num) if (temp > num)
{ {
++count; ++count;
...@@ -1288,10 +1292,10 @@ namespace dlib ...@@ -1288,10 +1292,10 @@ namespace dlib
} }
} }
if (count != num_total) if (count != num_total || has_slack == false)
{ {
std::ostringstream sout; std::ostringstream sout;
sout << "invalid nu of " << nu << ". Must be between 0 and " << (scalar_type)count/y.nr(); sout << "Invalid nu of " << nu << ". It is required that: 0 < nu < " << 2*(scalar_type)count/y.nr();
throw invalid_svm_nu_error(sout.str(),nu); throw invalid_svm_nu_error(sout.str(),nu);
} }
} }
......
...@@ -246,7 +246,7 @@ namespace dlib ...@@ -246,7 +246,7 @@ namespace dlib
- F(new_x) < 0 - F(new_x) < 0
throws throws
- invalid_svm_nu_error - invalid_svm_nu_error
This exception is thrown if get_nu() > maximum_nu(y) This exception is thrown if get_nu() >= maximum_nu(y)
- std::bad_alloc - std::bad_alloc
!*/ !*/
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment