Commit af20cb9f authored by Davis King's avatar Davis King

Fixed a bug in the cross_validate_trainer_threaded() function. It could deadlock if

more than about 10 folds were requested.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403137
parent 2e66a59d
......@@ -39,68 +39,26 @@ namespace dlib
scalar_vector_type y_test, y_train;
};
template <typename trainer_type>
void swap(
job<trainer_type>& a,
job<trainer_type>& b
)
{
exchange(a.trainer, b.trainer);
exchange(a.x_test, b.x_test);
exchange(a.y_test, b.y_test);
exchange(a.x_train, b.x_train);
exchange(a.y_train, b.y_train);
}
template <typename trainer_type>
class a_thread : multithreaded_object
struct task
{
public:
typedef typename trainer_type::scalar_type scalar_type;
typedef typename trainer_type::sample_type sample_type;
typedef typename trainer_type::mem_manager_type mem_manager_type;
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
explicit a_thread( long num_threads) : job_pipe(1), res_pipe(3)
template <
typename trainer_type,
typename matrix_type
>
void operator()(
const job<trainer_type>& j,
matrix_type& result
)
{
for (long i = 0; i < num_threads; ++i)
try
{
register_thread(*this, &a_thread::thread);
result = test_binary_decision_function(j.trainer.train(j.x_train, j.y_train), j.x_test, j.y_test);
}
start();
}
~a_thread()
{
// disable the job_pipe so that the threads will unblock and terminate
job_pipe.disable();
wait();
}
typename pipe<job<trainer_type> > ::kernel_1a job_pipe;
typename pipe<matrix<scalar_type, 1, 2, mem_manager_type> >::kernel_1a res_pipe;
private:
void thread()
{
job<trainer_type> j;
matrix<scalar_type, 1, 2, mem_manager_type> temp_res;
while (job_pipe.dequeue(j))
catch (invalid_svm_nu_error&)
{
try
{
temp_res = test_binary_decision_function(j.trainer.train(j.x_train, j.y_train), j.x_test, j.y_test);
}
catch (invalid_svm_nu_error&)
{
// If this is a svm_nu_trainer then we might get this exception if the nu is
// invalid. In this case just return a cross validation score of 0.
temp_res = 0;
}
res_pipe.enqueue(temp_res);
// If this is a svm_nu_trainer then we might get this exception if the nu is
// invalid. In this case just return a cross validation score of 0.
result = 0;
}
}
};
......@@ -140,6 +98,10 @@ namespace dlib
);
task mytask;
thread_pool tp(num_threads);
// count the number of positive and negative examples
long num_pos = 0;
long num_neg = 0;
......@@ -158,18 +120,19 @@ namespace dlib
const long num_neg_train_samples = num_neg - num_neg_test_samples;
typename trainer_type::trained_function_type d;
long pos_idx = 0;
long neg_idx = 0;
job<trainer_type> j;
a_thread<trainer_type> threads(num_threads);
std::vector<future<job<trainer_type> > > jobs(folds);
std::vector<future<matrix<scalar_type, 1, 2, mem_manager_type> > > results(folds);
for (long i = 0; i < folds; ++i)
{
job<trainer_type>& j = jobs[i].get();
j.x_test.set_size (num_pos_test_samples + num_neg_test_samples);
j.y_test.set_size (num_pos_test_samples + num_neg_test_samples);
j.x_train.set_size(num_pos_train_samples + num_neg_train_samples);
......@@ -232,21 +195,18 @@ namespace dlib
train_neg_idx = (train_neg_idx+1)%x.nr();
}
// add this job to the job pipe so that the threads
// will process it
threads.job_pipe.enqueue(j);
// finally spawn a task to process this job
tp.add_task(mytask, jobs[i], results[i]);
} // for (long i = 0; i < folds; ++i)
matrix<scalar_type, 1, 2, mem_manager_type> res;
matrix<scalar_type, 1, 2, mem_manager_type> temp_res;
set_all_elements(res,0);
// now wait for the threads to finish
// now compute the total results
for (long i = 0; i < folds; ++i)
{
threads.res_pipe.dequeue(temp_res);
res += temp_res;
res += results[i].get();
}
return res/(scalar_type)folds;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment