Commit 68412221 authored by Davis King

Improved the dnn_trainer. In particular, it no longer makes a copy of the
network (which would needlessly double VRAM usage). I also added a
set_synchronization_file() method so you can tell it to automatically
synchronize itself to disk every so often during training. This makes resuming
an interrupted training session trivially easy.

parent 4189386d
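As a rough usage sketch (not part of the commit itself), the snippet below shows how the new API is intended to be used. The network definition, layer names, file names, and dataset variables are illustrative assumptions written against the current dlib DNN interface, so the exact layer set may differ from what was available at this commit:

    #include <dlib/dnn.h>
    #include <chrono>
    #include <fstream>
    #include <vector>

    using namespace dlib;

    // Hypothetical toy network purely for illustration; substitute your own layers.
    using net_type = loss_multiclass_log<fc<10, relu<fc<84, input<matrix<unsigned char>>>>>>;

    int main()
    {
        std::vector<matrix<unsigned char>> images;  // assume these are loaded elsewhere
        std::vector<unsigned long> labels;

        net_type net;
        dnn_trainer<net_type> trainer(net);  // the trainer now references net, no copy is made

        // Save the trainer's full state (network, solvers, loop counters) every 5 minutes.
        // If "trainer_sync.dat" or "trainer_sync.dat.tmp" already exists, this call reloads
        // it, so rerunning the program resumes the interrupted training session.
        trainer.set_synchronization_file("trainer_sync.dat", std::chrono::minutes(5));

        trainer.train(images, labels);  // train() now returns void; the result lives in net

        // Since the trainer holds a reference, the trained parameters are already in net.
        std::ofstream fout("trained_network.dat", std::ios::binary);
        serialize(net, fout);
    }

Because the trainer only holds a reference to the network, the trained weights end up directly in the caller's net object, and killing and restarting the program simply picks training back up from the last synchronization point.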
@@ -8,6 +8,8 @@
 #include "solvers.h"
 #include "../statistics.h"
 #include <chrono>
+#include <fstream>
+#include <sstream>
 #include "../serialize.h"
 #include "../pipe.h"
@@ -15,6 +17,7 @@
 #include "cuda_dlib.h"
 #include "../statistics/running_gradient.h"
 #include <atomic>
+#include <cstdio>
 namespace dlib
 {
@@ -34,22 +37,20 @@ namespace dlib
         typedef typename net_type::label_type label_type;
         typedef typename net_type::input_type input_type;
+        const static size_t num_layers = net_type::num_layers;
-        dnn_trainer(
-        ) : job_pipe(0), solvers(net_type::num_layers)
-        {
-            init();
-        }
+        dnn_trainer() = delete;
+        dnn_trainer(const dnn_trainer&) = delete;
-        explicit dnn_trainer(const net_type& net_) : job_pipe(0), net(net_), solvers(net_type::num_layers)
+        explicit dnn_trainer(net_type& net_) : job_pipe(0), net(net_), solvers(num_layers)
         {
             init();
         }
         dnn_trainer(
-            const net_type& net_,
+            net_type& net_,
             const solver_type& solver_
-        ) : job_pipe(0), net(net_), solvers(net_type::num_layers, solver_)
+        ) : job_pipe(0), net(net_), solvers(num_layers, solver_)
         {
             init();
         }
@@ -62,27 +63,19 @@ namespace dlib
             wait();
         }
-        const net_type& get_net (
+        net_type& get_net (
         ) const
         {
            wait_for_thread_to_pause();
            return net;
         }
-        void set_net (
-            const net_type& net_
-        )
-        {
-            wait_for_thread_to_pause();
-            return net = net_;
-        }
         void set_solver (
             const solver_type& solver_
         )
         {
             wait_for_thread_to_pause();
-            solvers = std::vector<solver_type>(net_type::num_layers, solver_);
+            solvers = std::vector<solver_type>(num_layers, solver_);
         }
         unsigned long get_mini_batch_size (
@@ -140,6 +133,7 @@ namespace dlib
             const std::vector<label_type>& labels
         )
         {
+            sync_to_disk();
             job.labels = labels;
             net.to_tensor(data.begin(), data.end(), job.t);
             job_pipe.enqueue(job);
@@ -149,32 +143,39 @@ namespace dlib
             const std::vector<input_type>& data
         )
         {
+            sync_to_disk();
             net.to_tensor(data.begin(), data.end(), job.t);
             job_pipe.enqueue(job);
         }
-        const net_type& train (
+        void train (
             const std::vector<input_type>& data,
             const std::vector<label_type>& labels
         )
         {
             DLIB_CASSERT(data.size() == labels.size() && data.size() > 0, "");
-            for (unsigned long epoch_iteration = 0;
+            bool updated_the_network = false;
+            // The reason these two loops don't initialize their counter variables but
+            // instead use class members is so we can include the state of the loops in the
+            // stuff written by sync_to_disk()
+            for (;
                 epoch_iteration < max_num_epochs && step_size >= min_step_size;
                 ++epoch_iteration)
             {
                 using namespace std::chrono;
                 auto last_time = system_clock::now();
                 clear_average_loss();
-                for (size_t i = 0; i < data.size() && step_size >= min_step_size; i += mini_batch_size)
+                for (; epoch_pos < data.size() && step_size >= min_step_size; epoch_pos += mini_batch_size)
                 {
-                    net.to_tensor(data.begin()+i,
-                                  data.begin()+std::min(i+mini_batch_size,data.size()),
+                    sync_to_disk();
+                    net.to_tensor(data.begin()+epoch_pos,
+                                  data.begin()+std::min(epoch_pos+mini_batch_size,data.size()),
                                   job.t);
-                    job.labels.assign(labels.begin()+i,
-                                      labels.begin()+std::min(i+mini_batch_size,data.size()));
+                    job.labels.assign(labels.begin()+epoch_pos,
+                                      labels.begin()+std::min(epoch_pos+mini_batch_size,data.size()));
                     job_pipe.enqueue(job);
+                    updated_the_network = true;
                     if (verbose)
@@ -183,14 +184,16 @@ namespace dlib
                         if (now_time-last_time > seconds(20))
                         {
                             last_time = now_time;
-                            auto iter = epoch_iteration + i/(double)data.size();
+                            auto iter = epoch_iteration + epoch_pos/(double)data.size();
                             std::cout << "epoch: " << rpad(cast_to_string(iter),epoch_string_pad) << " "
                                       << "step size: " << rpad(cast_to_string(step_size),ss_string_pad) << " "
                                       << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad)
                                       << std::endl;
                         }
                     }
                 }
+                epoch_pos = 0;
                 if (verbose)
                 {
@@ -202,10 +205,12 @@ namespace dlib
                               << std::endl;
                 }
             }
-            return get_net();
+            wait_for_thread_to_pause();
+            // if we modified the network at all then be sure to sync the final result.
+            sync_to_disk(updated_the_network);
         }
-        const net_type& train (
+        void train (
             const std::vector<input_type>& data
         )
         {
@@ -215,19 +220,25 @@ namespace dlib
             static_assert(has_unsupervised_loss,
                 "You can only call this version of train() when using an unsupervised loss.");
-            for (unsigned long epoch_iteration = 0;
+            bool updated_the_network = false;
+            // The reason these two loops don't initialize their counter variables but
+            // instead use class members is so we can include the state of the loops in the
+            // stuff written by sync_to_disk()
+            for (;
                 epoch_iteration < max_num_epochs && step_size >= min_step_size;
                 ++epoch_iteration)
             {
                 using namespace std::chrono;
                 auto last_time = system_clock::now();
                 clear_average_loss();
-                for (size_t i = 0; i < data.size() && step_size >= min_step_size; i += mini_batch_size)
+                for (; epoch_pos < data.size() && step_size >= min_step_size; epoch_pos += mini_batch_size)
                 {
-                    net.to_tensor(data.begin()+i,
-                                  data.begin()+std::min(i+mini_batch_size,data.size()),
+                    sync_to_disk();
+                    net.to_tensor(data.begin()+epoch_pos,
+                                  data.begin()+std::min(epoch_pos+mini_batch_size,data.size()),
                                   job.t);
                     job_pipe.enqueue(job);
+                    updated_the_network = true;
                     if (verbose)
@@ -236,7 +247,7 @@ namespace dlib
                         if (now_time-last_time > seconds(20))
                         {
                             last_time = now_time;
-                            auto iter = epoch_iteration + i/(double)data.size();
+                            auto iter = epoch_iteration + epoch_pos/(double)data.size();
                             std::cout << "epoch: " << rpad(cast_to_string(iter),epoch_string_pad) << " "
                                       << "step size: " << rpad(cast_to_string(step_size),ss_string_pad) << " "
                                       << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad)
@@ -244,6 +255,7 @@ namespace dlib
                         }
                     }
                 }
+                epoch_pos = 0;
                 if (verbose)
                 {
@@ -255,48 +267,34 @@ namespace dlib
                               << std::endl;
                 }
             }
-            return get_net();
-        }
-        friend void serialize(const dnn_trainer& item, std::ostream& out)
-        {
-            item.wait_for_thread_to_pause();
-            int version = 3;
-            serialize(version, out);
-            serialize(item.rs, out);
-            serialize(item.rg, out);
-            serialize(item.max_num_epochs, out);
-            serialize(item.mini_batch_size, out);
-            serialize(item.verbose, out);
-            serialize(item.net, out);
-            serialize(item.solvers, out);
-            serialize(item.step_size.load(), out);
-            serialize(item.min_step_size, out);
-            serialize(item.iter_between_step_size_adjust.load(), out);
-            serialize(item.step_size_shrink.load(), out);
-        }
-        friend void deserialize(dnn_trainer& item, std::istream& in)
-        {
-            item.wait_for_thread_to_pause();
-            int version = 0;
-            deserialize(version, in);
-            if (version != 3)
-                throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
-            double temp;
-            deserialize(item.rs, in);
-            deserialize(item.rg, in);
-            deserialize(item.max_num_epochs, in);
-            deserialize(item.mini_batch_size, in);
-            deserialize(item.verbose, in);
-            deserialize(item.net, in);
-            deserialize(item.solvers, in);
-            deserialize(temp, in); item.step_size = temp;
-            deserialize(item.min_step_size, in);
-            deserialize(temp, in); item.iter_between_step_size_adjust = temp;
-            deserialize(temp, in); item.step_size_shrink = temp;
-        }
+            wait_for_thread_to_pause();
+            // if we modified the network at all then be sure to sync the final result.
+            sync_to_disk(updated_the_network);
+        }
+        void set_synchronization_file (
+            const std::string& filename,
+            std::chrono::seconds time_between_syncs_ = std::chrono::minutes(15)
+        )
+        {
+            last_sync_time = std::chrono::system_clock::now();
+            sync_filename = filename;
+            time_between_syncs = time_between_syncs_;
+            // check if the sync file already exists, if it does we should load it.  We
+            // first check for a .tmp version since that would be the newest if it existed.
+            // If it doesn't exist we check the canonical sync file.
+            std::ifstream fin(filename+".tmp", std::ios::binary);
+            if (fin)
+            {
+                deserialize(*this, fin);
+            }
+            else
+            {
+                std::ifstream fin(filename, std::ios::binary);
+                if (fin)
+                    deserialize(*this, fin);
+            }
+        }
         double get_average_loss (
@@ -442,9 +440,102 @@ namespace dlib
             min_step_size = 1e-4;
             iter_between_step_size_adjust = 2000;
             step_size_shrink = 0.1;
+            epoch_iteration = 0;
+            epoch_pos = 0;
             start();
         }
+        // serialize and deserialize are private because we hold net by reference so
+        // allowing someone to serialize this training object is weird and will likely
+        // result in user errors.  However, we use these functions as part of the automatic
+        // sync code in this object.
+        friend void serialize(const dnn_trainer& item, std::ostream& out)
+        {
+            item.wait_for_thread_to_pause();
+            int version = 3;
+            serialize(version, out);
+            size_t nl = dnn_trainer::num_layers;
+            serialize(nl, out);
+            serialize(item.rs, out);
+            serialize(item.rg, out);
+            serialize(item.max_num_epochs, out);
+            serialize(item.mini_batch_size, out);
+            serialize(item.verbose, out);
+            serialize(item.net, out);
+            serialize(item.solvers, out);
+            serialize(item.step_size.load(), out);
+            serialize(item.min_step_size, out);
+            serialize(item.iter_between_step_size_adjust.load(), out);
+            serialize(item.step_size_shrink.load(), out);
+            serialize(item.epoch_iteration, out);
+            serialize(item.epoch_pos, out);
+        }
+        friend void deserialize(dnn_trainer& item, std::istream& in)
+        {
+            item.wait_for_thread_to_pause();
+            int version = 0;
+            deserialize(version, in);
+            if (version != 3)
+                throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
+            size_t num_layers = 0;
+            deserialize(num_layers, in);
+            if (num_layers != dnn_trainer::num_layers)
+            {
+                std::ostringstream sout;
+                sout << "Error deserializing dlib::dnn_trainer. The saved sync file is for a network with " << std::endl;
+                sout << "a different number of layers. We expected the number of layers to be " << dnn_trainer::num_layers << " but" << std::endl;
+                sout << "instead the file contains " << num_layers << " layers." << std::endl;
+                throw serialization_error(sout.str());
+            }
+            double dtemp; long ltemp;
+            deserialize(item.rs, in);
+            deserialize(item.rg, in);
+            deserialize(item.max_num_epochs, in);
+            deserialize(item.mini_batch_size, in);
+            deserialize(item.verbose, in);
+            deserialize(item.net, in);
+            deserialize(item.solvers, in);
+            deserialize(dtemp, in); item.step_size = dtemp;
+            deserialize(item.min_step_size, in);
+            deserialize(ltemp, in); item.iter_between_step_size_adjust = ltemp;
+            deserialize(dtemp, in); item.step_size_shrink = dtemp;
+            deserialize(item.epoch_iteration, in);
+            deserialize(item.epoch_pos, in);
+        }
+        void sync_to_disk (
+            bool do_it_now = false
+        )
+        {
+            // If the sync file isn't set then don't do anything.
+            if (sync_filename.size() == 0)
+                return;
+            // Only sync if it has been long enough since the last sync or we are being
+            // explicitly forced to do it.
+            if (std::chrono::system_clock::now() - last_sync_time > time_between_syncs ||
+                do_it_now)
+            {
+                // save our state to a temp file
+                std::string tempfile = sync_filename + ".tmp";
+                std::ofstream fout(tempfile, std::ios::binary);
+                serialize(*this, fout);
+                fout.close();
+                // Now that we know the state is safely saved to disk, delete the old sync
+                // file and move the .tmp file to it.
+                std::remove(sync_filename.c_str());
+                std::rename(tempfile.c_str(), sync_filename.c_str());
+                last_sync_time = std::chrono::system_clock::now();
+                if (verbose)
+                    std::cout << "Saved state to " << sync_filename << std::endl;
+            }
+        }
         dlib::pipe<job_t> job_pipe;
         running_stats<double> rs;
@@ -453,12 +544,17 @@ namespace dlib
         size_t mini_batch_size;
         bool verbose;
         int cuda_device_id;
-        net_type net;
+        net_type& net;
         std::vector<solver_type> solvers;
         std::atomic<double> step_size;
         double min_step_size;
         std::atomic<long> iter_between_step_size_adjust;
         std::atomic<double> step_size_shrink;
+        std::chrono::time_point<std::chrono::system_clock> last_sync_time;
+        std::string sync_filename;
+        std::chrono::seconds time_between_syncs;
+        unsigned long epoch_iteration;
+        unsigned long epoch_pos;
         // The job object is not logically part of the state of this object.  It is here
         // only to avoid reallocating it over and over.
...
@@ -6,6 +6,7 @@
 #include "core_abstract.h"
 #include "solvers_abstract.h"
 #include <vector>
+#include <chrono>
 namespace dlib
@@ -39,43 +40,21 @@ namespace dlib
         typedef typename net_type::label_type label_type;
         typedef typename net_type::input_type input_type;
+        const static size_t num_layers = net_type::num_layers;
-        dnn_trainer(
-        );
-        /*!
-            ensures
-                - #get_net() == a default initialized net_type object.
-                - #get_solvers() == a set of default initialized solvers.
-                - #get_max_num_epochs() == 10000
-                - #get_mini_batch_size() == 128
-                - #get_step_size() == 1
-                - #get_min_step_size() == 1e-4
-                - #get_iterations_between_step_size_adjust() == 2000
-                - #get_step_size_shrink() == 0.1
-        !*/
-        explicit dnn_trainer(
-            const net_type& net
-        );
-        /*!
-            ensures
-                - #get_net() == net
-                - #get_solvers() == a set of default initialized solvers.
-                - #get_max_num_epochs() == 10000
-                - #get_mini_batch_size() == 128
-                - #get_step_size() == 1
-                - #get_min_step_size() == 1e-4
-                - #get_iterations_between_step_size_adjust() == 2000
-                - #get_step_size_shrink() == 0.1
-        !*/
+        dnn_trainer() = delete;
+        dnn_trainer(const dnn_trainer&) = delete;
         dnn_trainer(
-            const net_type& net,
-            const solver_type& solver
+            net_type& net,
+            const solver_type& solver = solver_type()
         );
         /*!
             ensures
-                - #get_net() == net
+                - &#get_net() == &net
+                  (i.e. The dnn_trainer holds a reference to net, it does not copy it.
+                  Therefore, you must ensure net has a lifetime at least as long as the
+                  dnn_trainer).
                 - #get_solvers() == a set of solvers that are all initialized with the
                   provided solver instance.
                 - #get_max_num_epochs() == 10000
@@ -86,20 +65,15 @@ namespace dlib
                 - #get_step_size_shrink() == 0.1
         !*/
-        const net_type& get_net (
+        net_type& get_net (
         ) const;
         /*!
             ensures
-                - returns the neural network object in this trainer.  This is the network
-                  that is optimized when you call train().
-        !*/
-        void set_net (
-            const net_type& net
-        );
-        /*!
-            ensures
-                - #get_net() == net
+                - returns the neural network object used by this trainer.  This is the
+                  network that is optimized when you call train() or train_one_step().
+                  Recall that the dnn_trainer doesn't contain the net_type object but
+                  simply holds a reference to an external network which was provided to the
+                  dnn_trainer's constructor.
         !*/
         void set_solver (
@@ -275,7 +249,23 @@ namespace dlib
                 - This object will not print anything to standard out
         !*/
-        const net_type& train (
+        void set_synchronization_file (
+            const std::string& filename,
+            std::chrono::seconds time_between_syncs = std::chrono::minutes(15)
+        );
+        /*!
+            ensures
+                - While training is running, either via train() or repeated calls to
+                  train_one_step(), this object will save its entire state, including the
+                  state of get_net(), to disk in the file named filename every
+                  time_between_syncs seconds.
+                - if the filename file already exists then the state of this trainer will
+                  be loaded from that file by this call to set_synchronization_file().
+                  This allows you to resume a training session which was previously
+                  interrupted.
+        !*/
+        void train (
             const std::vector<input_type>& data,
             const std::vector<label_type>& labels
         );
@@ -292,22 +282,17 @@ namespace dlib
                   get_max_num_epochs() training epochs have been executes.
                 - Each layer in the network will be optimized by its corresponding solver
                   in get_solvers().
-                - returns #get_net()
-                  (i.e. the trained network can also be accessed by calling get_net() after
-                  train() finishes executing)
                 - Each call to train DOES NOT reinitialize the state of get_net() or
-                  get_solvers().  That is, the state of the solvers and network contained
-                  inside this trainer is the starting point for the optimization each time
-                  train() is called.  For example, calling train() 1 time and having it
-                  execute 100 epochs of training is equivalent to calling train() 10 times
-                  and having it execute 10 epochs of training during each call.  This also
-                  means you can serialize a trainer to disk and then, at a later date,
-                  deserialize it and resume training your network where you left off.
+                  get_solvers().  That is, the existing state of the solvers and network is
+                  the starting point for the optimization each time train() is called.  In
+                  particular, if you use the set_synchronization_file() method you can
+                  resume an interrupted train() call by simply calling train() again and it
+                  will pick up from the last synchronization point.
                 - You can obtain the average loss value during the final training epoch by
                   calling get_average_loss().
         !*/
-        const net_type& train (
+        void train (
             const std::vector<input_type>& data
         );
         /*!
@@ -322,17 +307,12 @@ namespace dlib
                   get_max_num_epochs() training epochs have been executes.
                 - Each layer in the network will be optimized by its corresponding solver
                   in get_solvers().
-                - returns #get_net()
-                  (i.e. the trained network can also be accessed by calling get_net() after
-                  train() finishes executing)
                 - Each call to train DOES NOT reinitialize the state of get_net() or
-                  get_solvers().  That is, the state of the solvers and network contained
-                  inside this trainer is the starting point for the optimization each time
-                  train() is called.  For example, calling train() 1 time and having it
-                  execute 100 epochs of training is equivalent to calling train() 10 times
-                  and having it execute 10 epochs of training during each call.  This also
-                  means you can serialize a trainer to disk and then, at a later date,
-                  deserialize it and resume training your network where you left off.
+                  get_solvers().  That is, the existing state of the solvers and network is
+                  the starting point for the optimization each time train() is called.  In
+                  particular, if you use the set_synchronization_file() method you can
+                  resume an interrupted train() call by simply calling train() again and it
+                  will pick up from the last synchronization point.
                 - You can obtain the average loss value during the final training epoch by
                   calling get_average_loss().
         !*/
@@ -398,14 +378,6 @@ namespace dlib
     };
-    template <typename T, typename U>
-    void serialize(const dnn_trainer<T,U>& item, std::ostream& out);
-    template <typename T, typename U>
-    void deserialize(dnn_trainer<T,U>& item, std::istream& in);
-    /*!
-        provides serialization support
-    !*/
 // ----------------------------------------------------------------------------------------
 }
...