Commit 2ab4add4 authored by Davis King's avatar Davis King

Added some missing validation of the user supplied number of folds

to the cross_validate_multiclass_trainer() routine.  Not it will
throw an exception if the number of folds is too big rather than
just producing a confusing result.
parent 7c436153
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "../matrix.h" #include "../matrix.h"
#include "one_vs_one_trainer.h" #include "one_vs_one_trainer.h"
#include "cross_validate_multiclass_trainer_abstract.h" #include "cross_validate_multiclass_trainer_abstract.h"
#include <sstream>
namespace dlib namespace dlib
{ {
...@@ -67,6 +68,12 @@ namespace dlib ...@@ -67,6 +68,12 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
class cross_validation_error : public dlib::error
{
public:
cross_validation_error(const std::string& msg) : dlib::error(msg){};
};
template < template <
typename trainer_type, typename trainer_type,
typename sample_type, typename sample_type,
...@@ -104,6 +111,15 @@ namespace dlib ...@@ -104,6 +111,15 @@ namespace dlib
for (typename std::map<label_type,long>::iterator i = label_counts.begin(); i != label_counts.end(); ++i) for (typename std::map<label_type,long>::iterator i = label_counts.begin(); i != label_counts.end(); ++i)
{ {
const long in_test = i->second/folds; const long in_test = i->second/folds;
if (in_test == 0)
{
std::ostringstream sout;
sout << "In dlib::cross_validate_multiclass_trainer(), the number of folds was larger" << std::endl;
sout << "than the number of elements of one of the training classes." << std::endl;
sout << " folds: "<< folds << std::endl;
sout << " size of class " << i->first << ": "<< i->second << std::endl;
throw cross_validation_error(sout.str());
}
num_in_test[i->first] = in_test; num_in_test[i->first] = in_test;
num_in_train[i->first] = i->second - in_test; num_in_train[i->first] = i->second - in_test;
} }
......
...@@ -38,6 +38,16 @@ namespace dlib ...@@ -38,6 +38,16 @@ namespace dlib
with labels the decision function hasn't ever seen before are ignored. with labels the decision function hasn't ever seen before are ignored.
!*/ !*/
// ----------------------------------------------------------------------------------------
class cross_validation_error : public dlib::error
{
/*!
This is the exception class used by the cross_validate_multiclass_trainer()
routine.
!*/
};
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -74,6 +84,10 @@ namespace dlib ...@@ -74,6 +84,10 @@ namespace dlib
samples in a class is not an even multiple of folds. This is because each fold has the samples in a class is not an even multiple of folds. This is because each fold has the
same number of test samples in it and so if the number of samples in a class isn't a same number of test samples in it and so if the number of samples in a class isn't a
multiple of folds then a few are not tested. multiple of folds then a few are not tested.
throws
- cross_validation_error
This exception is thrown if one of the classes has fewer samples than
the number of requested folds.
!*/ !*/
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment