Commit ca11d108 authored by Davis King

Added multi-gpu support to the dnn_trainer

parent b9fd9564
......@@ -18,12 +18,38 @@
#include "../statistics/running_gradient.h"
#include <atomic>
#include <cstdio>
#include <set>
#include <future>
namespace dlib
{
// ----------------------------------------------------------------------------------------
namespace impl
{
// One unit of training work for the dnn_trainer.  The three vectors are
// parallel and are indexed by device: slot i holds the input tensor and the
// labels destined for the i-th device, and have_data[i] says whether that
// slot was filled for this job.
//
// Copy construction and copy assignment are disabled; jobs are exchanged with
// the swap() overload declared right after this struct.
// NOTE(review): presumably this is what lets the job travel through a
// dlib::pipe without copying tensors -- confirm against dlib::pipe.
template <typename label_type>
struct dnn_job_t
{
dnn_job_t() = default;
dnn_job_t(const dnn_job_t&) = delete;
dnn_job_t& operator=(const dnn_job_t&) = delete;
// labels[i] are the labels that go with the samples stored in t[i].
std::vector<std::vector<label_type>> labels;
// t[i] is the block of input samples, already converted to a tensor, for device i.
std::vector<resizable_tensor> t;
std::vector<int> have_data; // have_data[i] is true if there is data in labels[i] and t[i].
};
template <typename label_type>
void swap(dnn_job_t<label_type>& a, dnn_job_t<label_type>& b)
{
a.labels.swap(b.labels);
a.t.swap(b.t);
a.have_data.swap(b.have_data);
}
}
template <
typename net_type,
typename solver_type = sgd
......@@ -38,20 +64,59 @@ namespace dlib
typedef typename net_type::label_type label_type;
typedef typename net_type::input_type input_type;
const static size_t num_computational_layers = net_type::num_computational_layers;
const static size_t num_layers = net_type::num_layers;
private:
typedef impl::dnn_job_t<label_type> job_t;
public:
dnn_trainer() = delete;
dnn_trainer(const dnn_trainer&) = delete;
dnn_trainer& operator=(const dnn_trainer&) = delete;
explicit dnn_trainer(net_type& net_) : job_pipe(0), net(net_), solvers(num_computational_layers)
explicit dnn_trainer(net_type& net_) : job_pipe(0), net(net_)
{
solver_type default_solver;
devices.push_back(std::make_shared<device_data>(dlib::cuda::get_device(), net, default_solver));
init();
}
dnn_trainer(
net_type& net_,
const solver_type& solver_
) : job_pipe(0), net(net_), solvers(num_computational_layers, solver_)
) : job_pipe(0), net(net_)
{
devices.push_back(std::make_shared<device_data>(dlib::cuda::get_device(), net, solver_));
init();
}
// Multi-device constructor.
//
// net_               - the network to train.  The trainer keeps a reference
//                      to it; it is used directly on the device that is
//                      current on the calling thread.
// solver_            - prototype solver; each device context gets one copy
//                      per computational layer.
// cuda_extra_devices - ids of additional CUDA devices to train on.  Duplicate
//                      ids, and the id of the current device, are ignored.
//                      Every id must satisfy
//                      0 <= id < dlib::cuda::get_num_devices().
//
// The CUDA device selected on the calling thread is the same on exit as it
// was on entry.
dnn_trainer(
    net_type& net_,
    const solver_type& solver_,
    const std::vector<int>& cuda_extra_devices
) : job_pipe(0), net(net_)
{
    // devices[0] is always the device that is current on the calling thread,
    // and it works directly on the caller's network.
    devices.push_back(std::make_shared<device_data>(dlib::cuda::get_device(), net, solver_));

    const int total_devices = dlib::cuda::get_num_devices();

    // Make device contexts for the extra device ids but be careful to avoid any
    // duplicate ids.
    std::set<int> temp(cuda_extra_devices.begin(), cuda_extra_devices.end());
    temp.erase(devices[0]->device_id);

    // Validate every id BEFORE switching devices.  If the check were done
    // inside the loop below, a bad id that follows a good one would throw
    // while the calling thread is still switched to the good one, silently
    // changing the caller's current CUDA device.
    for (auto id : temp)
    {
        DLIB_CASSERT(0 <= id && id < total_devices, "Invalid CUDA device id given to dnn_trainer.");
    }

    for (auto id : temp)
    {
        // Switch to this device so that any tensor objects that get allocated when
        // we create the device context happen on this device.
        dlib::cuda::set_device(id);
        devices.push_back(std::make_shared<device_data>(id, net, solver_, clone_net()));
    }

    // Set the current device back to what it was before this constructor was
    // called.
    dlib::cuda::set_device(devices[0]->device_id);

    init();
}
......@@ -70,13 +135,6 @@ namespace dlib
return net;
}
void set_solver (
const solver_type& solver_
)
{
wait_for_thread_to_pause();
solvers = std::vector<solver_type>(num_computational_layers, solver_);
}
unsigned long get_mini_batch_size (
) const { return mini_batch_size; }
......@@ -117,22 +175,16 @@ namespace dlib
) const
{
wait_for_thread_to_pause();
return solvers;
return devices[0]->solvers;
}
std::vector<solver_type>& get_solvers (
)
{
wait_for_thread_to_pause();
return solvers;
}
void train_one_step (
const std::vector<input_type>& data,
const std::vector<label_type>& labels
)
{
DLIB_CASSERT(data.size() == labels.size() && data.size() > 0, "");
if (verbose)
{
using namespace std::chrono;
......@@ -149,9 +201,8 @@ namespace dlib
}
}
sync_to_disk();
job.labels = labels;
net.to_tensor(data.begin(), data.end(), job.t);
job_pipe.enqueue(job);
send_job(data.begin(), data.end(), labels.begin());
++train_one_step_calls;
}
......@@ -159,6 +210,7 @@ namespace dlib
const std::vector<input_type>& data
)
{
DLIB_CASSERT(data.size() > 0, "");
if (verbose)
{
using namespace std::chrono;
......@@ -175,8 +227,7 @@ namespace dlib
}
}
sync_to_disk();
net.to_tensor(data.begin(), data.end(), job.t);
job_pipe.enqueue(job);
send_job(data.begin(), data.end());
++train_one_step_calls;
}
......@@ -216,12 +267,9 @@ namespace dlib
}
sync_to_disk();
net.to_tensor(data.begin()+epoch_pos,
data.begin()+std::min(epoch_pos+mini_batch_size,data.size()),
job.t);
job.labels.assign(labels.begin()+epoch_pos,
labels.begin()+std::min(epoch_pos+mini_batch_size,data.size()));
job_pipe.enqueue(job);
send_job(data.begin()+epoch_pos,
data.begin()+std::min(epoch_pos+mini_batch_size,data.size()),
labels.begin()+epoch_pos);
updated_the_network = true;
}
epoch_pos = 0;
......@@ -281,10 +329,8 @@ namespace dlib
}
sync_to_disk();
net.to_tensor(data.begin()+epoch_pos,
data.begin()+std::min(epoch_pos+mini_batch_size,data.size()),
job.t);
job_pipe.enqueue(job);
send_job(data.begin()+epoch_pos,
data.begin()+std::min(epoch_pos+mini_batch_size,data.size()));
updated_the_network = true;
}
epoch_pos = 0;
......@@ -393,11 +439,6 @@ namespace dlib
}
private:
struct job_t
{
std::vector<label_type> labels;
resizable_tensor t;
};
void record_loss(double loss)
{
......@@ -416,34 +457,98 @@ namespace dlib
}
template <typename T>
void run_update(job_t& next_job, const T&)
double compute_parameter_gradients(size_t device, job_t& next_job, const T&)
{
double loss = net.compute_parameter_gradients(next_job.t, next_job.labels.begin());
net.update_parameters(make_sstack(solvers),step_size);
record_loss(loss);
if (next_job.have_data[device])
{
auto&& dev = *devices[device];
dlib::cuda::set_device(dev.device_id);
return dev.net.compute_parameter_gradients(next_job.t[device], next_job.labels[device].begin());
}
else
{
return 0;
}
}
void run_update(job_t& next_job, const no_label_type&)
double compute_parameter_gradients(size_t device, job_t& next_job, const no_label_type&)
{
no_label_type pick_which_run_update;
double loss = net.compute_parameter_gradients(next_job.t);
net.update_parameters(make_sstack(solvers), step_size);
record_loss(loss);
if (next_job.have_data[device])
{
auto&& dev = *devices[device];
dlib::cuda::set_device(dev.device_id);
no_label_type pick_which_run_update;
return dev.net.compute_parameter_gradients(next_job.t[device]);
}
else
{
return 0;
}
}
// Applies one solver step to the network that lives on devices[device].
// The CUDA device is selected first so the update runs on that device, and
// the trainer's current step_size is used for every layer's solver.
void update_parameters(size_t device)
{
    const auto& ctx = devices[device];
    dlib::cuda::set_device(ctx->device_id);
    ctx->net.update_parameters(make_sstack(ctx->solvers), step_size);
}
void thread() try
{
// Make sure this thread uses the same cuda device as the thread that created
// the dnn_trainer object.
dlib::cuda::set_device(cuda_device_id);
label_type pick_which_run_update;
job_t next_job;
std::vector<std::future<double>> losses(devices.size());
std::vector<std::future<void>> update_futs(devices.size());
std::vector<matrix<float>> param_buffer(net_type::num_computational_layers);
while(job_pipe.dequeue(next_job))
{
// call net.compute_parameter_gradients() and net.update_parameters() but
// pick the right version for unsupervised or supervised training based on
// the type of label_type.
run_update(next_job, pick_which_run_update);
// Call compute_parameter_gradients() and update_parameters() but pick the
// right version for unsupervised or supervised training based on the type
// of label_type.
for (size_t i = 0; i < devices.size(); ++i)
losses[i] = std::async(std::launch::async,[&,i](){ return compute_parameter_gradients(i, next_job, pick_which_run_update); });
// aggregate loss values from all the network computations.
for (auto&& loss : losses)
record_loss(loss.get());
// Now, if there is more than one active device we need to synchronize the
// gradient updates between devices. So we do that now.
if (devices.size() > 1)
{
for (auto&& p : param_buffer)
p = 0;
// now average all the parameter gradients
for (size_t i = 0; i < devices.size(); ++i)
{
visit_layer_parameters(devices[i]->net, [&param_buffer](size_t j, tensor& t) {
if (t.size() != 0)
param_buffer[j] += mat(t);
});
}
// and then assign the parameter gradients back to all the networks
const float scale = 1.0f/devices.size();
for (size_t i = 0; i < devices.size(); ++i)
{
visit_layer_parameters(devices[i]->net, [scale,&param_buffer](size_t j, tensor& t) {
if (t.size() != 0)
{
t = param_buffer[j]*scale;
t.async_copy_to_device();
}
});
}
}
// Now apply all the updates to each device.
for (size_t i = 0; i < devices.size(); ++i)
update_futs[i] = std::async(std::launch::async, [&,i](){ if (next_job.have_data[i]) update_parameters(i); });
// and wait for the updates to all happen.
for (auto&& f : update_futs)
f.wait();
// If we have been running for a while then check if the loss is still
// dropping. If it isn't then we will reduce the step size. Note that we
......@@ -484,7 +589,6 @@ namespace dlib
max_num_epochs = 10000;
mini_batch_size = 128;
verbose = false;
cuda_device_id = dlib::cuda::get_device();
step_size = 1;
min_step_size = 1e-3;
iter_without_progress_thresh = 2000;
......@@ -504,10 +608,10 @@ namespace dlib
friend void serialize(const dnn_trainer& item, std::ostream& out)
{
item.wait_for_thread_to_pause();
int version = 5;
int version = 6;
serialize(version, out);
size_t nl = dnn_trainer::num_computational_layers;
size_t nl = dnn_trainer::num_layers;
serialize(nl, out);
serialize(item.rs, out);
serialize(item.previous_loss_values, out);
......@@ -515,7 +619,7 @@ namespace dlib
serialize(item.mini_batch_size, out);
serialize(item.verbose, out);
serialize(item.net, out);
serialize(item.solvers, out);
serialize(item.devices[0]->solvers, out);
serialize(item.step_size.load(), out);
serialize(item.min_step_size, out);
serialize(item.iter_without_progress_thresh.load(), out);
......@@ -530,17 +634,17 @@ namespace dlib
item.wait_for_thread_to_pause();
int version = 0;
deserialize(version, in);
if (version != 5)
if (version != 6)
throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
size_t num_computational_layers = 0;
deserialize(num_computational_layers, in);
if (num_computational_layers != dnn_trainer::num_computational_layers)
size_t num_layers = 0;
deserialize(num_layers, in);
if (num_layers != dnn_trainer::num_layers)
{
std::ostringstream sout;
sout << "Error deserializing dlib::dnn_trainer. The saved sync file is for a network with " << std::endl;
sout << "a different number of layers. We expected the number of computational layers to be " << dnn_trainer::num_computational_layers << " but" << std::endl;
sout << "instead the file contains " << num_computational_layers << " computational layers." << std::endl;
sout << "a different number of layers. We expected the number of layers to be " << dnn_trainer::num_layers << " but" << std::endl;
sout << "instead the file contains " << num_layers << " layers." << std::endl;
throw serialization_error(sout.str());
}
......@@ -551,7 +655,7 @@ namespace dlib
deserialize(item.mini_batch_size, in);
deserialize(item.verbose, in);
deserialize(item.net, in);
deserialize(item.solvers, in);
deserialize(item.devices[0]->solvers, in);
deserialize(dtemp, in); item.step_size = dtemp;
deserialize(item.min_step_size, in);
deserialize(ltemp, in); item.iter_without_progress_thresh = ltemp;
......@@ -560,6 +664,21 @@ namespace dlib
deserialize(item.epoch_iteration, in);
deserialize(item.epoch_pos, in);
deserialize(item.train_one_step_calls, in);
if (item.devices.size() > 1)
{
const auto prev_dev = dlib::cuda::get_device();
// initialize all the other device networks and solver objects
for (size_t i = 1; i < item.devices.size(); ++i)
{
// Switch to this device so that any tensor objects that get allocated when
// we copy this stuff happen on this device.
dlib::cuda::set_device(item.devices[i]->device_id);
item.devices[i]->solvers = item.devices[0]->solvers;
item.devices[i]->net = item.devices[0]->net;
}
dlib::cuda::set_device(prev_dev);
}
}
void sync_to_disk (
bool do_it_now = false
......@@ -594,16 +713,96 @@ namespace dlib
}
// Empty tag type.  Passing a clone_net as the last constructor argument of
// device_data selects the overload that makes its own copy of the network
// instead of referring to the caller's network.
struct clone_net{};
// per device state. All the containers have the same number of objects in them.
//
// Groups what the trainer keeps for one CUDA device: the device id, a
// reference to the network used on that device, and one solver per
// computational layer of the network.
struct device_data
{
// Binds `net` directly to the network passed in; no copy is made and
// net_copy is left empty.
device_data(
int device_id_,
net_type& net_,
const solver_type& solver_
) : device_id(device_id_), net(net_), solvers(num_computational_layers, solver_) {}
// Selected by passing a clone_net tag as the last argument.  Makes a heap
// allocated copy of net_, stores it in net_copy, and binds `net` to that
// copy rather than to net_ itself.
device_data(
int device_id_,
net_type& net_,
const solver_type& solver_,
clone_net
) : device_id(device_id_), net_copy(std::make_shared<net_type>(net_)), net(*net_copy), solvers(num_computational_layers, solver_) {}
// CUDA device id this state belongs to.
int device_id;
// Owns the cloned network when the clone_net constructor was used.
// Keep this declared before `net`: members are initialized in declaration
// order and the clone_net constructor initializes `net` from *net_copy.
std::shared_ptr<net_type> net_copy;
// The network used on this device: either the caller's network or *net_copy.
net_type& net;
// One solver per computational layer (num_computational_layers entries).
std::vector<solver_type> solvers;
};
template <
typename data_iterator,
typename label_iterator
>
void send_job (
data_iterator dbegin,
data_iterator dend,
label_iterator lbegin
)
{
size_t num = std::distance(dbegin, dend);
size_t devs = devices.size();
job.t.resize(devs);
job.labels.resize(devs);
job.have_data.resize(devs);
// chop the data into devs blocks, each of about block_size elements.
size_t block_size = (num+devs-1)/devs;
const auto prev_dev = dlib::cuda::get_device();
for (size_t i = 0; i < devs; ++i)
{
dlib::cuda::set_device(devices[i]->device_id);
size_t start = i*block_size;
size_t stop = std::min(num, start+block_size);
if (start < stop)
{
devices[i]->net.to_tensor(dbegin+start, dbegin+stop, job.t[i]);
job.labels[i].assign(lbegin+start, lbegin+stop);
job.have_data[i] = true;
}
else
{
job.have_data[i] = false;
}
}
dlib::cuda::set_device(prev_dev);
job_pipe.enqueue(job);
}
template <
typename data_iterator
>
void send_job (
data_iterator dbegin,
data_iterator dend
)
{
typename std::vector<label_type>::iterator nothing;
send_job(dbegin, dend, nothing);
}
std::vector<std::shared_ptr<device_data>> devices;
dlib::pipe<job_t> job_pipe;
job_t job;
running_stats<double> rs;
std::deque<double> previous_loss_values;
unsigned long max_num_epochs;
size_t mini_batch_size;
bool verbose;
int cuda_device_id;
net_type& net;
std::vector<solver_type> solvers;
std::atomic<double> step_size;
double min_step_size;
std::atomic<unsigned long> iter_without_progress_thresh;
......@@ -618,9 +817,8 @@ namespace dlib
unsigned long long train_one_step_calls;
unsigned long gradient_check_budget;
// The job object is not logically part of the state of this object. It is here
// only to avoid reallocating it over and over.
job_t job;
};
// ----------------------------------------------------------------------------------------
......
......@@ -48,12 +48,17 @@ namespace dlib
dnn_trainer() = delete;
dnn_trainer(const dnn_trainer&) = delete;
dnn_trainer& operator=(const dnn_trainer&) = delete;
dnn_trainer(
net_type& net,
const solver_type& solver = solver_type()
const solver_type& solver = solver_type(),
const std::vector<int>& cuda_extra_devices = {}
);
/*!
requires
- for all valid i:
- 0 <= cuda_extra_devices[i] < dlib::cuda::get_num_devices()
ensures
- &#get_net() == &net
(i.e. The dnn_trainer holds a reference to net, it does not copy it.
......@@ -67,6 +72,13 @@ namespace dlib
- #get_min_step_size() == 1e-3
- #get_iterations_without_progress_threshold() == 2000
- #get_step_size_shrink() == 0.1
- if (cuda_extra_devices.size() > 0) then
- This object will use multiple graphics cards to run the learning
algorithms. In particular, it will always use whatever device is
currently selected on the calling thread (the device indicated by
cudaGetDevice()). In addition, you can ask to use additional
devices, which you do by putting their device numbers into
cuda_extra_devices.
!*/
net_type& get_net (
......@@ -82,15 +94,6 @@ namespace dlib
stopped touching the net.
!*/
void set_solver (
const solver_type& solver
);
/*!
ensures
- assigns solver to all the solvers in this object. I.e. solver will be
assigned to each element in get_solvers().
!*/
const std::vector<solver_type>& get_solvers (
) const;
/*!
......@@ -101,22 +104,6 @@ namespace dlib
get_solvers()[1], and so on.
!*/
std::vector<solver_type>& get_solvers (
);
/*!
ensures
- returns the solvers used to optimize each layer of the neural network
get_net(). In particular, the first layer's solver is
get_solvers()[0], the second layer's solver is
get_solvers()[1], and so on.
- It should be noted that you should never change the number of elements in
the vector returned by get_solvers() (i.e. don't do something that changes
get_solvers().size()). It will be set to net_type::num_computational_layers
by this object and you should leave it at that. The non-const version of
get_solvers() is provided only so you can tweak the parameters of a
particular solver.
!*/
unsigned long get_mini_batch_size (
) const;
/*!
......@@ -289,6 +276,7 @@ namespace dlib
/*!
requires
- data.size() == labels.size()
- data.size() > 0
- net_type uses a supervised loss.
i.e. net_type::label_type != no_label_type.
ensures
......@@ -314,6 +302,7 @@ namespace dlib
);
/*!
requires
- data.size() > 0
- net_type uses an unsupervised loss.
i.e. net_type::label_type == no_label_type.
ensures
......@@ -341,6 +330,7 @@ namespace dlib
/*!
requires
- data.size() == labels.size()
- data.size() > 0
- net_type uses a supervised loss.
i.e. net_type::label_type != no_label_type.
ensures
......@@ -363,6 +353,7 @@ namespace dlib
);
/*!
requires
- data.size() > 0
- net_type uses an unsupervised loss.
i.e. net_type::label_type == no_label_type.
ensures
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment