Commit af20cb9f authored by Davis King's avatar Davis King

Fixed a bug in the cross_validate_trainer_threaded() function. It could deadlock if

more than about 10 folds were requested.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403137
parent 2e66a59d
...@@ -39,68 +39,26 @@ namespace dlib ...@@ -39,68 +39,26 @@ namespace dlib
scalar_vector_type y_test, y_train; scalar_vector_type y_test, y_train;
}; };
template <typename trainer_type> struct task
void swap(
job<trainer_type>& a,
job<trainer_type>& b
)
{
exchange(a.trainer, b.trainer);
exchange(a.x_test, b.x_test);
exchange(a.y_test, b.y_test);
exchange(a.x_train, b.x_train);
exchange(a.y_train, b.y_train);
}
template <typename trainer_type>
class a_thread : multithreaded_object
{
public:
typedef typename trainer_type::scalar_type scalar_type;
typedef typename trainer_type::sample_type sample_type;
typedef typename trainer_type::mem_manager_type mem_manager_type;
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
explicit a_thread( long num_threads) : job_pipe(1), res_pipe(3)
{
for (long i = 0; i < num_threads; ++i)
{
register_thread(*this, &a_thread::thread);
}
start();
}
~a_thread()
{ {
// disable the job_pipe so that the threads will unblock and terminate template <
job_pipe.disable(); typename trainer_type,
wait(); typename matrix_type
} >
void operator()(
typename pipe<job<trainer_type> > ::kernel_1a job_pipe; const job<trainer_type>& j,
typename pipe<matrix<scalar_type, 1, 2, mem_manager_type> >::kernel_1a res_pipe; matrix_type& result
)
private:
void thread()
{
job<trainer_type> j;
matrix<scalar_type, 1, 2, mem_manager_type> temp_res;
while (job_pipe.dequeue(j))
{ {
try try
{ {
temp_res = test_binary_decision_function(j.trainer.train(j.x_train, j.y_train), j.x_test, j.y_test); result = test_binary_decision_function(j.trainer.train(j.x_train, j.y_train), j.x_test, j.y_test);
} }
catch (invalid_svm_nu_error&) catch (invalid_svm_nu_error&)
{ {
// If this is a svm_nu_trainer then we might get this exception if the nu is // If this is a svm_nu_trainer then we might get this exception if the nu is
// invalid. In this case just return a cross validation score of 0. // invalid. In this case just return a cross validation score of 0.
temp_res = 0; result = 0;
}
res_pipe.enqueue(temp_res);
} }
} }
}; };
...@@ -140,6 +98,10 @@ namespace dlib ...@@ -140,6 +98,10 @@ namespace dlib
); );
task mytask;
thread_pool tp(num_threads);
// count the number of positive and negative examples // count the number of positive and negative examples
long num_pos = 0; long num_pos = 0;
long num_neg = 0; long num_neg = 0;
...@@ -158,18 +120,19 @@ namespace dlib ...@@ -158,18 +120,19 @@ namespace dlib
const long num_neg_train_samples = num_neg - num_neg_test_samples; const long num_neg_train_samples = num_neg - num_neg_test_samples;
typename trainer_type::trained_function_type d;
long pos_idx = 0; long pos_idx = 0;
long neg_idx = 0; long neg_idx = 0;
job<trainer_type> j; std::vector<future<job<trainer_type> > > jobs(folds);
a_thread<trainer_type> threads(num_threads); std::vector<future<matrix<scalar_type, 1, 2, mem_manager_type> > > results(folds);
for (long i = 0; i < folds; ++i) for (long i = 0; i < folds; ++i)
{ {
job<trainer_type>& j = jobs[i].get();
j.x_test.set_size (num_pos_test_samples + num_neg_test_samples); j.x_test.set_size (num_pos_test_samples + num_neg_test_samples);
j.y_test.set_size (num_pos_test_samples + num_neg_test_samples); j.y_test.set_size (num_pos_test_samples + num_neg_test_samples);
j.x_train.set_size(num_pos_train_samples + num_neg_train_samples); j.x_train.set_size(num_pos_train_samples + num_neg_train_samples);
...@@ -232,21 +195,18 @@ namespace dlib ...@@ -232,21 +195,18 @@ namespace dlib
train_neg_idx = (train_neg_idx+1)%x.nr(); train_neg_idx = (train_neg_idx+1)%x.nr();
} }
// add this job to the job pipe so that the threads // finally spawn a task to process this job
// will process it tp.add_task(mytask, jobs[i], results[i]);
threads.job_pipe.enqueue(j);
} // for (long i = 0; i < folds; ++i) } // for (long i = 0; i < folds; ++i)
matrix<scalar_type, 1, 2, mem_manager_type> res; matrix<scalar_type, 1, 2, mem_manager_type> res;
matrix<scalar_type, 1, 2, mem_manager_type> temp_res;
set_all_elements(res,0); set_all_elements(res,0);
// now wait for the threads to finish // now compute the total results
for (long i = 0; i < folds; ++i) for (long i = 0; i < folds; ++i)
{ {
threads.res_pipe.dequeue(temp_res); res += results[i].get();
res += temp_res;
} }
return res/(scalar_type)folds; return res/(scalar_type)folds;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment