Commit af20cb9f authored by Davis King

Fixed a bug in the cross_validate_trainer_threaded() function. It could deadlock if
more than about 10 folds were requested.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403137
parent 2e66a59d
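Note on the bug: in the old code below, the main thread enqueued one job per fold into job_pipe (capacity 1) before dequeuing anything from res_pipe (capacity 3). Once res_pipe was full, every worker thread blocked trying to enqueue its next result, so only roughly num_threads + 4 folds could be in flight, and the next job_pipe.enqueue() blocked forever with nothing left to drain res_pipe. Below is a minimal, self-contained sketch of that pattern; it uses the current dlib::pipe and dlib::thread_function APIs rather than the kernel_1a typedefs in this diff, and the squaring worker is a hypothetical stand-in for training and testing one fold. Run with a fold count much above ~10, it hangs, which is the reported symptom.

#include <dlib/pipe.h>
#include <dlib/threads.h>
#include <iostream>

dlib::pipe<int> job_pipe(1);   // capacity 1, as in the old a_thread constructor
dlib::pipe<int> res_pipe(3);   // capacity 3, as in the old a_thread constructor

void worker()
{
    int job;
    while (job_pipe.dequeue(job))
    {
        int result = job * job;     // stand-in for train-then-test on one fold
        res_pipe.enqueue(result);   // blocks once res_pipe already holds 3 results
    }
}

int main()
{
    const int folds = 20;
    dlib::thread_function t1(worker), t2(worker), t3(worker), t4(worker);

    // Phase 1: enqueue every job before reading any result, exactly like the
    // old cross-validation loop.  With 4 workers, only about 3 (res_pipe)
    // + 4 (workers stuck holding a result) + 1 (job_pipe) jobs can be
    // absorbed, so an enqueue around the 9th job blocks forever.
    for (int i = 0; i < folds; ++i)
    {
        int job = i;
        job_pipe.enqueue(job);
    }

    // Phase 2: drain the results (never reached for folds much above ~8-10).
    int result, sum = 0;
    for (int i = 0; i < folds; ++i)
    {
        res_pipe.dequeue(result);
        sum += result;
    }
    job_pipe.disable();             // lets the workers exit their dequeue loops
    std::cout << sum << std::endl;
    return 0;
}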
@@ -39,68 +39,26 @@ namespace dlib
         scalar_vector_type y_test, y_train;
     };
 
-    template <typename trainer_type>
-    void swap(
-        job<trainer_type>& a,
-        job<trainer_type>& b
-    )
-    {
-        exchange(a.trainer, b.trainer);
-        exchange(a.x_test, b.x_test);
-        exchange(a.y_test, b.y_test);
-        exchange(a.x_train, b.x_train);
-        exchange(a.y_train, b.y_train);
-    }
-
-    template <typename trainer_type>
-    class a_thread : multithreaded_object
-    {
-    public:
-        typedef typename trainer_type::scalar_type scalar_type;
-        typedef typename trainer_type::sample_type sample_type;
-        typedef typename trainer_type::mem_manager_type mem_manager_type;
-        typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
-        typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
-
-        explicit a_thread( long num_threads) : job_pipe(1), res_pipe(3)
-        {
-            for (long i = 0; i < num_threads; ++i)
-            {
-                register_thread(*this, &a_thread::thread);
-            }
-            start();
-        }
-
-        ~a_thread()
-        {
-            // disable the job_pipe so that the threads will unblock and terminate
-            job_pipe.disable();
-            wait();
-        }
-
-        typename pipe<job<trainer_type> >::kernel_1a job_pipe;
-        typename pipe<matrix<scalar_type, 1, 2, mem_manager_type> >::kernel_1a res_pipe;
-
-    private:
-        void thread()
-        {
-            job<trainer_type> j;
-            matrix<scalar_type, 1, 2, mem_manager_type> temp_res;
-            while (job_pipe.dequeue(j))
-            {
-                try
-                {
-                    temp_res = test_binary_decision_function(j.trainer.train(j.x_train, j.y_train), j.x_test, j.y_test);
-                }
-                catch (invalid_svm_nu_error&)
-                {
-                    // If this is a svm_nu_trainer then we might get this exception if the nu is
-                    // invalid. In this case just return a cross validation score of 0.
-                    temp_res = 0;
-                }
-                res_pipe.enqueue(temp_res);
-            }
-        }
-    };
+    struct task
+    {
+        template <
+            typename trainer_type,
+            typename matrix_type
+            >
+        void operator()(
+            const job<trainer_type>& j,
+            matrix_type& result
+        )
+        {
+            try
+            {
+                result = test_binary_decision_function(j.trainer.train(j.x_train, j.y_train), j.x_test, j.y_test);
+            }
+            catch (invalid_svm_nu_error&)
+            {
+                // If this is a svm_nu_trainer then we might get this exception if the nu is
+                // invalid. In this case just return a cross validation score of 0.
+                result = 0;
+            }
+        }
+    };
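Note on the fix: the hand-rolled worker class and its fixed-capacity pipes are gone. Each fold is now a stateless task object handed to a dlib::thread_pool, with inputs and outputs carried in dlib::future objects, so there is no bounded result queue for the workers to jam against; the caller simply waits on each future. A detail the later hunks rely on: calling get() on a future that has not yet been given to add_task() just returns the wrapped object immediately, which is how the setup loop fills in each jobs[i] in place. Below is a hedged sketch of the same submit-then-collect pattern; the squarer task and the numbers are illustrative, not from the diff.

#include <dlib/threads.h>
#include <iostream>

// Hypothetical task object, shaped like the `task` struct above.
struct squarer
{
    void operator()(const int& in, int& out) const { out = in * in; }
};

int main()
{
    dlib::thread_pool tp(4);
    squarer mytask;

    dlib::future<int> in[20];
    dlib::future<int> out[20];
    for (int i = 0; i < 20; ++i)
    {
        in[i] = i;                          // a fresh future is just a box: assign the input
        tp.add_task(mytask, in[i], out[i]); // queue one "fold" on the pool
    }

    int sum = 0;
    for (int i = 0; i < 20; ++i)
        sum += out[i].get();                // get() blocks until that task has run
    std::cout << sum << std::endl;          // 2470 = sum of squares 0..19
    return 0;
}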
@@ -140,6 +98,10 @@ namespace dlib
         );
 
+        task mytask;
+        thread_pool tp(num_threads);
+
         // count the number of positive and negative examples
         long num_pos = 0;
         long num_neg = 0;
@@ -158,18 +120,19 @@ namespace dlib
         const long num_neg_train_samples = num_neg - num_neg_test_samples;
 
-        typename trainer_type::trained_function_type d;
-
         long pos_idx = 0;
         long neg_idx = 0;
 
-        job<trainer_type> j;
-        a_thread<trainer_type> threads(num_threads);
+        std::vector<future<job<trainer_type> > > jobs(folds);
+        std::vector<future<matrix<scalar_type, 1, 2, mem_manager_type> > > results(folds);
 
         for (long i = 0; i < folds; ++i)
         {
+            job<trainer_type>& j = jobs[i].get();
+
             j.x_test.set_size (num_pos_test_samples + num_neg_test_samples);
             j.y_test.set_size (num_pos_test_samples + num_neg_test_samples);
             j.x_train.set_size(num_pos_train_samples + num_neg_train_samples);
@@ -232,21 +195,18 @@ namespace dlib
                 train_neg_idx = (train_neg_idx+1)%x.nr();
             }
 
-            // add this job to the job pipe so that the threads
-            // will process it
-            threads.job_pipe.enqueue(j);
+            // finally spawn a task to process this job
+            tp.add_task(mytask, jobs[i], results[i]);
 
         } // for (long i = 0; i < folds; ++i)
 
         matrix<scalar_type, 1, 2, mem_manager_type> res;
-        matrix<scalar_type, 1, 2, mem_manager_type> temp_res;
         set_all_elements(res,0);
 
-        // now wait for the threads to finish
+        // now compute the total results
        for (long i = 0; i < folds; ++i)
         {
-            threads.res_pipe.dequeue(temp_res);
-            res += temp_res;
+            res += results[i].get();
         }
 
         return res/(scalar_type)folds;
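For reference, this is roughly how the fixed function is exercised. A hedged sketch against the current dlib API; the toy dataset, kernel choice, and parameter values are made up for illustration. A fold count like 20 is exactly the case that used to deadlock.

#include <dlib/svm_threaded.h>
#include <iostream>

int main()
{
    typedef dlib::matrix<double, 2, 1> sample_type;
    typedef dlib::radial_basis_kernel<sample_type> kernel_type;

    // 40 toy samples: class +1 near the origin, class -1 shifted away.
    dlib::matrix<sample_type, 0, 1> x(40);
    dlib::matrix<double, 0, 1> y(40);
    for (long i = 0; i < 40; ++i)
    {
        const double v = (i % 20) * 0.1;
        sample_type s;
        if (i < 20) { s = v, v;           y(i) = +1; }
        else        { s = v + 3, v + 3;   y(i) = -1; }
        x(i) = s;
    }

    dlib::svm_nu_trainer<kernel_type> trainer;
    trainer.set_kernel(kernel_type(0.1));
    trainer.set_nu(0.05);

    // 20 folds on 4 threads: each fold is now just one task on the pool,
    // so high fold counts no longer hang.  Prints the fraction of +1 and
    // -1 examples classified correctly.
    std::cout << dlib::cross_validate_trainer_threaded(trainer, x, y, 20, 4);
    return 0;
}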