Commit 97474376 authored by Davis King's avatar Davis King

Changed code to avoid recreating thread_local cuda context objects.

parent e55afabd
......@@ -535,7 +535,15 @@ namespace dlib
std::vector<tensor*> reference_params;
visit_layer_parameters(devices[0]->net, [&](size_t, tensor& t) { reference_params.push_back(&t); });
thread_pool tp(devices.size());
// We make separate thread pools with just one thread in them because we want
// to make sure each device is always executed on the same thread. We care
// about this because there are thread_local context variables for some cuda
// components and they get regenerated when the current cuda device changes.
// Recreating them over and over is somewhat expensive so we want to avoid
// that.
std::vector<std::shared_ptr<thread_pool>> tp;
for (size_t i = 0; i < devices.size(); ++i)
tp.push_back(std::make_shared<thread_pool>(1));
size_t iteration = 0;
......@@ -546,7 +554,7 @@ namespace dlib
// right version for unsupervised or supervised training based on the type
// of label_type.
for (size_t i = 0; i < devices.size(); ++i)
tp.add_task_by_value([&,i](double& loss){ loss = compute_parameter_gradients(i, next_job, pick_which_run_update); }, losses[i]);
tp[i]->add_task_by_value([&,i](double& loss){ loss = compute_parameter_gradients(i, next_job, pick_which_run_update); }, losses[i]);
// aggregate loss values from all the network computations.
double theloss = 0;
for (auto&& loss : losses)
......@@ -597,9 +605,10 @@ namespace dlib
// Now apply all the updates to each device.
for (size_t i = 0; i < devices.size(); ++i)
tp.add_task_by_value([&,i](){ if (next_job.have_data[i]) update_parameters(i); });
tp[i]->add_task_by_value([&,i](){ if (next_job.have_data[i]) update_parameters(i); });
// and wait for the updates to all happen.
tp.wait_for_all_tasks();
for (size_t i = 0; i < devices.size(); ++i)
tp[i]->wait_for_all_tasks();
// Evey now and then force all the parameters to be the same just to make
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment