Commit 269ba90a authored by Davis King

Made cross_validate_trainer() and cross_validate_trainer_threaded() not make
duplicate copies of the training data, since doing so uses a lot of RAM for
large datasets.
parent 3fe820cc
......@@ -447,7 +447,7 @@ namespace dlib
const long num_neg_train_samples = num_neg - num_neg_test_samples;
sample_vector_type x_test, x_train;
matrix<long,0,1> x_test, x_train;
scalar_vector_type y_test, y_train;
x_test.set_size (num_pos_test_samples + num_neg_test_samples);
y_test.set_size (num_pos_test_samples + num_neg_test_samples);
......@@ -469,7 +469,7 @@ namespace dlib
{
if (y(pos_idx) == +1.0)
{
x_test(cur) = x(pos_idx);
x_test(cur) = pos_idx;
y_test(cur) = +1.0;
++cur;
}
......@@ -481,7 +481,7 @@ namespace dlib
{
if (y(neg_idx) == -1.0)
{
x_test(cur) = x(neg_idx);
x_test(cur) = neg_idx;
y_test(cur) = -1.0;
++cur;
}
......@@ -499,7 +499,7 @@ namespace dlib
{
if (y(train_pos_idx) == +1.0)
{
x_train(cur) = x(train_pos_idx);
x_train(cur) = train_pos_idx;
y_train(cur) = +1.0;
++cur;
}
......@@ -511,7 +511,7 @@ namespace dlib
{
if (y(train_neg_idx) == -1.0)
{
x_train(cur) = x(train_neg_idx);
x_train(cur) = train_neg_idx;
y_train(cur) = -1.0;
++cur;
}
......@@ -521,7 +521,7 @@ namespace dlib
try
{
// do the training and testing
res += test_binary_decision_function(trainer.train(x_train,y_train),x_test,y_test);
res += test_binary_decision_function(trainer.train(rowm(x,x_train),y_train),rowm(x,x_test),y_test);
}
catch (invalid_nu_error&)
{
......
......@@ -26,7 +26,7 @@ namespace dlib
namespace cvtti_helpers
{
template <typename trainer_type>
template <typename trainer_type, typename in_sample_vector_type>
struct job
{
typedef typename trainer_type::scalar_type scalar_type;
......@@ -35,29 +35,33 @@ namespace dlib
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
job() : x(0) {}
trainer_type trainer;
sample_vector_type x_test, x_train;
matrix<long,0,1> x_test, x_train;
scalar_vector_type y_test, y_train;
const in_sample_vector_type* x;
};
struct task
{
template <
typename trainer_type,
typename matrix_type
typename matrix_type,
typename in_sample_vector_type
>
void operator()(
job<trainer_type>& j,
job<trainer_type,in_sample_vector_type>& j,
matrix_type& result
)
{
try
{
result = test_binary_decision_function(j.trainer.train(j.x_train, j.y_train), j.x_test, j.y_test);
result = test_binary_decision_function(j.trainer.train(rowm(*j.x,j.x_train), j.y_train), rowm(*j.x,j.x_test), j.y_test);
// Do this just to make j release its memory since people might run threaded cross validation
// on very large datasets. Every bit of freed memory helps out.
j = job<trainer_type>();
j = job<trainer_type,in_sample_vector_type>();
}
catch (invalid_nu_error&)
{
......@@ -132,14 +136,15 @@ namespace dlib
std::vector<future<job<trainer_type> > > jobs(folds);
std::vector<future<job<trainer_type,in_sample_vector_type> > > jobs(folds);
std::vector<future<matrix<scalar_type, 1, 2, mem_manager_type> > > results(folds);
for (long i = 0; i < folds; ++i)
{
job<trainer_type>& j = jobs[i].get();
job<trainer_type,in_sample_vector_type>& j = jobs[i].get();
j.x = &x;
j.x_test.set_size (num_pos_test_samples + num_neg_test_samples);
j.y_test.set_size (num_pos_test_samples + num_neg_test_samples);
j.x_train.set_size(num_pos_train_samples + num_neg_train_samples);
......@@ -153,7 +158,7 @@ namespace dlib
{
if (y(pos_idx) == +1.0)
{
j.x_test(cur) = x(pos_idx);
j.x_test(cur) = pos_idx;
j.y_test(cur) = +1.0;
++cur;
}
......@@ -165,7 +170,7 @@ namespace dlib
{
if (y(neg_idx) == -1.0)
{
j.x_test(cur) = x(neg_idx);
j.x_test(cur) = neg_idx;
j.y_test(cur) = -1.0;
++cur;
}
......@@ -183,7 +188,7 @@ namespace dlib
{
if (y(train_pos_idx) == +1.0)
{
j.x_train(cur) = x(train_pos_idx);
j.x_train(cur) = train_pos_idx;
j.y_train(cur) = +1.0;
++cur;
}
......@@ -195,7 +200,7 @@ namespace dlib
{
if (y(train_neg_idx) == -1.0)
{
j.x_train(cur) = x(train_neg_idx);
j.x_train(cur) = train_neg_idx;
j.y_train(cur) = -1.0;
++cur;
}
......
......@@ -32,6 +32,9 @@ namespace dlib
(e.g. There must be at least as many examples of each class as there are folds)
- trainer_type == some kind of trainer object (e.g. svm_nu_trainer)
- num_threads > 0
- It must be safe for multiple trainer objects to access the elements of x from
multiple threads at the same time. Note that all trainers and kernels in
dlib are thread safe in this regard since they do not mutate the elements of x.
ensures
- performs k-fold cross validation by using the given trainer to solve the
given binary classification problem for the given number of folds.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment