Commit 269ba90a authored by Davis King

Made cross_validate_trainer() and cross_validate_trainer_threaded() not make
duplicate copies of the training data, since doing so uses a lot of RAM for
large datasets.
parent 3fe820cc
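The whole change rides on dlib's rowm(): given a column vector of row indices, rowm(x, idx) is a lazy matrix expression that reads the selected rows of x in place, so each fold can be described by a cheap index vector instead of a second copy of the samples. Below is a minimal standalone sketch of that pattern; the toy numbers are invented for illustration, but rowm() and the comma initializer are real dlib API.

#include <dlib/matrix.h>
#include <iostream>

int main()
{
    using namespace dlib;

    // Stand-in for the sample vector x: one sample per row.
    matrix<double,0,1> x;
    x.set_size(5);
    x = 10, 20, 30, 40, 50;

    // Instead of copying a fold's samples into a new vector, the commit
    // stores just their row indices...
    matrix<long,0,1> x_train;
    x_train.set_size(3);
    x_train = 0, 2, 4;

    // ...and hands the trainer rowm(x, x_train).  rowm() returns a lazy
    // expression over the original data, so no per-fold copy of the
    // samples is ever materialized.
    std::cout << rowm(x, x_train);   // prints the samples 10, 30, 50
    return 0;
}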
@@ -447,7 +447,7 @@ namespace dlib
             const long num_neg_train_samples = num_neg - num_neg_test_samples;
-            sample_vector_type x_test, x_train;
+            matrix<long,0,1> x_test, x_train;
             scalar_vector_type y_test, y_train;
             x_test.set_size (num_pos_test_samples + num_neg_test_samples);
             y_test.set_size (num_pos_test_samples + num_neg_test_samples);
@@ -469,7 +469,7 @@ namespace dlib
             {
                 if (y(pos_idx) == +1.0)
                 {
-                    x_test(cur) = x(pos_idx);
+                    x_test(cur) = pos_idx;
                     y_test(cur) = +1.0;
                     ++cur;
                 }
@@ -481,7 +481,7 @@ namespace dlib
             {
                 if (y(neg_idx) == -1.0)
                 {
-                    x_test(cur) = x(neg_idx);
+                    x_test(cur) = neg_idx;
                     y_test(cur) = -1.0;
                     ++cur;
                 }
@@ -499,7 +499,7 @@ namespace dlib
             {
                 if (y(train_pos_idx) == +1.0)
                 {
-                    x_train(cur) = x(train_pos_idx);
+                    x_train(cur) = train_pos_idx;
                     y_train(cur) = +1.0;
                     ++cur;
                 }
@@ -511,7 +511,7 @@ namespace dlib
             {
                 if (y(train_neg_idx) == -1.0)
                 {
-                    x_train(cur) = x(train_neg_idx);
+                    x_train(cur) = train_neg_idx;
                     y_train(cur) = -1.0;
                     ++cur;
                 }
@@ -521,7 +521,7 @@ namespace dlib
             try
             {
                 // do the training and testing
-                res += test_binary_decision_function(trainer.train(x_train,y_train),x_test,y_test);
+                res += test_binary_decision_function(trainer.train(rowm(x,x_train),y_train),rowm(x,x_test),y_test);
             }
             catch (invalid_nu_error&)
             {
...
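For orientation, here is a hedged usage sketch of the function this file changes. The toy dataset, kernel parameter, and fold count are invented for illustration; cross_validate_trainer(), svm_nu_trainer, and radial_basis_kernel are real dlib components.

#include <dlib/svm.h>
#include <iostream>
#include <vector>

int main()
{
    using namespace dlib;
    typedef matrix<double,2,1> sample_type;
    typedef radial_basis_kernel<sample_type> kernel_type;

    // Invented toy binary classification problem: 10 samples per class.
    std::vector<sample_type> samples;
    std::vector<double> labels;
    for (int i = 0; i < 20; ++i)
    {
        sample_type samp;
        samp = i, i % 5;
        samples.push_back(samp);
        labels.push_back(i < 10 ? +1.0 : -1.0);
    }

    svm_nu_trainer<kernel_type> trainer;
    trainer.set_kernel(kernel_type(0.1));

    // 2-fold cross validation.  After this commit each fold holds only an
    // index vector into samples, not a second copy of the samples.
    std::cout << cross_validate_trainer(trainer, samples, labels, 2);
    return 0;
}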
@@ -26,7 +26,7 @@ namespace dlib
     namespace cvtti_helpers
     {
-        template <typename trainer_type>
+        template <typename trainer_type, typename in_sample_vector_type>
         struct job
         {
             typedef typename trainer_type::scalar_type scalar_type;
@@ -35,29 +35,33 @@ namespace dlib
             typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
             typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
+            job() : x(0) {}
             trainer_type trainer;
-            sample_vector_type x_test, x_train;
+            matrix<long,0,1> x_test, x_train;
             scalar_vector_type y_test, y_train;
+            const in_sample_vector_type* x;
         };
         struct task
         {
             template <
                 typename trainer_type,
-                typename matrix_type
+                typename matrix_type,
+                typename in_sample_vector_type
                 >
             void operator()(
-                job<trainer_type>& j,
+                job<trainer_type,in_sample_vector_type>& j,
                 matrix_type& result
             )
             {
                 try
                 {
-                    result = test_binary_decision_function(j.trainer.train(j.x_train, j.y_train), j.x_test, j.y_test);
+                    result = test_binary_decision_function(j.trainer.train(rowm(*j.x,j.x_train), j.y_train), rowm(*j.x,j.x_test), j.y_test);
                     // Do this just to make j release it's memory since people might run threaded cross validation
                     // on very large datasets.  Every bit of freed memory helps out.
-                    j = job<trainer_type>();
+                    j = job<trainer_type,in_sample_vector_type>();
                 }
                 catch (invalid_nu_error&)
                 {
@@ -132,14 +136,15 @@ namespace dlib
-        std::vector<future<job<trainer_type> > > jobs(folds);
+        std::vector<future<job<trainer_type,in_sample_vector_type> > > jobs(folds);
         std::vector<future<matrix<scalar_type, 1, 2, mem_manager_type> > > results(folds);
         for (long i = 0; i < folds; ++i)
         {
-            job<trainer_type>& j = jobs[i].get();
+            job<trainer_type,in_sample_vector_type>& j = jobs[i].get();
+            j.x = &x;
             j.x_test.set_size (num_pos_test_samples + num_neg_test_samples);
             j.y_test.set_size (num_pos_test_samples + num_neg_test_samples);
             j.x_train.set_size(num_pos_train_samples + num_neg_train_samples);
@@ -153,7 +158,7 @@ namespace dlib
             {
                 if (y(pos_idx) == +1.0)
                 {
-                    j.x_test(cur) = x(pos_idx);
+                    j.x_test(cur) = pos_idx;
                     j.y_test(cur) = +1.0;
                     ++cur;
                 }
@@ -165,7 +170,7 @@ namespace dlib
             {
                 if (y(neg_idx) == -1.0)
                 {
-                    j.x_test(cur) = x(neg_idx);
+                    j.x_test(cur) = neg_idx;
                     j.y_test(cur) = -1.0;
                     ++cur;
                 }
@@ -183,7 +188,7 @@ namespace dlib
             {
                 if (y(train_pos_idx) == +1.0)
                 {
-                    j.x_train(cur) = x(train_pos_idx);
+                    j.x_train(cur) = train_pos_idx;
                     j.y_train(cur) = +1.0;
                     ++cur;
                 }
@@ -195,7 +200,7 @@ namespace dlib
             {
                 if (y(train_neg_idx) == -1.0)
                 {
-                    j.x_train(cur) = x(train_neg_idx);
+                    j.x_train(cur) = train_neg_idx;
                     j.y_train(cur) = -1.0;
                     ++cur;
                 }
...
@@ -32,6 +32,9 @@ namespace dlib
               (e.g. There must be at least as many examples of each class as there are folds)
             - trainer_type == some kind of trainer object (e.g. svm_nu_trainer)
             - num_threads > 0
+            - It must be safe for multiple trainer objects to access the elements of x from
+              multiple threads at the same time.  Note that all trainers and kernels in
+              dlib are thread safe in this regard since they do not mutate the elements of x.
         ensures
             - performs k-fold cross validation by using the given trainer to solve the
               given binary classification problem for the given number of folds.
...
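The added requires clause above is what licenses the sharing: every fold now trains against the same in-memory samples concurrently instead of a private copy, which is only sound because the trainers read but never mutate them. Below is a hedged sketch of the threaded call; the toy dataset is invented, while cross_validate_trainer_threaded() and its trailing num_threads argument are the real dlib API.

#include <dlib/svm_threaded.h>
#include <iostream>
#include <vector>

int main()
{
    using namespace dlib;
    typedef matrix<double,2,1> sample_type;
    typedef radial_basis_kernel<sample_type> kernel_type;

    // Invented toy binary classification problem: 20 samples per class.
    std::vector<sample_type> samples;
    std::vector<double> labels;
    for (int i = 0; i < 40; ++i)
    {
        sample_type samp;
        samp = i, (i * 7) % 11;
        samples.push_back(samp);
        labels.push_back(i < 20 ? +1.0 : -1.0);
    }

    svm_nu_trainer<kernel_type> trainer;
    trainer.set_kernel(kernel_type(0.1));

    // 4 folds trained on 4 threads.  Each fold reads the one shared copy
    // of samples through its index vector, which is safe because dlib's
    // trainers and kernels never mutate the samples they are given.
    std::cout << cross_validate_trainer_threaded(trainer, samples, labels, 4, 4);
    return 0;
}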