Commit 68412221 authored by Davis King

Improved the dnn_trainer.  In particular, it no longer makes a copy of the
network (which would needlessly double VRAM usage).  I also added a
set_synchronization_file() method so you can tell it to automatically
synchronize itself to disk every so often during training.  This makes resuming
an interrupted training session trivially easy.
parent 4189386d
@@ -6,6 +6,7 @@
 #include "core_abstract.h"
 #include "solvers_abstract.h"
 #include <vector>
+#include <chrono>
 
 namespace dlib
@@ -39,43 +40,21 @@ namespace dlib
         typedef typename net_type::label_type label_type;
         typedef typename net_type::input_type input_type;
 
-        dnn_trainer(
-        );
-        /*!
-            ensures
-                - #get_net() == a default initialized net_type object.
-                - #get_solvers() == a set of default initialized solvers.
-                - #get_max_num_epochs() == 10000
-                - #get_mini_batch_size() == 128
-                - #get_step_size() == 1
-                - #get_min_step_size() == 1e-4
-                - #get_iterations_between_step_size_adjust() == 2000
-                - #get_step_size_shrink() == 0.1
-        !*/
-
-        explicit dnn_trainer(
-            const net_type& net
-        );
-        /*!
-            ensures
-                - #get_net() == net
-                - #get_solvers() == a set of default initialized solvers.
-                - #get_max_num_epochs() == 10000
-                - #get_mini_batch_size() == 128
-                - #get_step_size() == 1
-                - #get_min_step_size() == 1e-4
-                - #get_iterations_between_step_size_adjust() == 2000
-                - #get_step_size_shrink() == 0.1
-        !*/
+        const static size_t num_layers = net_type::num_layers;
+
+        dnn_trainer() = delete;
+        dnn_trainer(const dnn_trainer&) = delete;
 
         dnn_trainer(
-            const net_type& net,
-            const solver_type& solver
+            net_type& net,
+            const solver_type& solver = solver_type()
         );
         /*!
             ensures
-                - #get_net() == net
+                - &#get_net() == &net
+                  (i.e. The dnn_trainer holds a reference to net; it does not copy it.
+                  Therefore, you must ensure net has a lifetime at least as long as the
+                  dnn_trainer.)
                 - #get_solvers() == a set of solvers that are all initialized with the
                   provided solver instance.
                 - #get_max_num_epochs() == 10000
@@ -86,20 +65,15 @@
                 - #get_step_size_shrink() == 0.1
         !*/
 
-        const net_type& get_net (
+        net_type& get_net (
         ) const;
         /*!
             ensures
-                - returns the neural network object in this trainer.  This is the network
-                  that is optimized when you call train().
-        !*/
-
-        void set_net (
-            const net_type& net
-        );
-        /*!
-            ensures
-                - #get_net() == net
+                - returns the neural network object used by this trainer.  This is the
+                  network that is optimized when you call train() or train_one_step().
+                  Recall that the dnn_trainer doesn't contain the net_type object but
+                  simply holds a reference to an external network which was provided to the
+                  dnn_trainer's constructor.
         !*/
 
         void set_solver (
@@ -275,7 +249,23 @@
                 - This object will not print anything to standard out
         !*/
 
-        const net_type& train (
+        void set_synchronization_file (
+            const std::string& filename,
+            std::chrono::seconds time_between_syncs = std::chrono::minutes(15)
+        );
+        /*!
+            ensures
+                - While training is running, either via train() or repeated calls to
+                  train_one_step(), this object will save its entire state, including the
+                  state of get_net(), to disk in the file named filename every
+                  time_between_syncs seconds.
+                - If the filename file already exists then the state of this trainer will
+                  be loaded from that file by this call to set_synchronization_file().
+                  This allows you to resume a training session which was previously
+                  interrupted.
+        !*/
+
+        void train (
             const std::vector<input_type>& data,
             const std::vector<label_type>& labels
         );
@@ -292,22 +282,17 @@
               get_max_num_epochs() training epochs have been executed.
             - Each layer in the network will be optimized by its corresponding solver
               in get_solvers().
-            - returns #get_net()
-              (i.e. the trained network can also be accessed by calling get_net() after
-              train() finishes executing)
             - Each call to train DOES NOT reinitialize the state of get_net() or
-              get_solvers().  That is, the state of the solvers and network contained
-              inside this trainer is the starting point for the optimization each time
-              train() is called.  For example, calling train() 1 time and having it
-              execute 100 epochs of training is equivalent to calling train() 10 times
-              and having it execute 10 epochs of training during each call.  This also
-              means you can serialize a trainer to disk and then, at a later date,
-              deserialize it and resume training your network where you left off.
+              get_solvers().  That is, the existing state of the solvers and network is
+              the starting point for the optimization each time train() is called.  In
+              particular, if you use the set_synchronization_file() method you can
+              resume an interrupted train() call by simply calling train() again and it
+              will pick up from the last synchronization point.
             - You can obtain the average loss value during the final training epoch by
               calling get_average_loss().
         !*/
 
-        const net_type& train (
+        void train (
             const std::vector<input_type>& data
         );
         /*!
@@ -322,17 +307,12 @@
               get_max_num_epochs() training epochs have been executed.
             - Each layer in the network will be optimized by its corresponding solver
               in get_solvers().
-            - returns #get_net()
-              (i.e. the trained network can also be accessed by calling get_net() after
-              train() finishes executing)
             - Each call to train DOES NOT reinitialize the state of get_net() or
-              get_solvers().  That is, the state of the solvers and network contained
-              inside this trainer is the starting point for the optimization each time
-              train() is called.  For example, calling train() 1 time and having it
-              execute 100 epochs of training is equivalent to calling train() 10 times
-              and having it execute 10 epochs of training during each call.  This also
-              means you can serialize a trainer to disk and then, at a later date,
-              deserialize it and resume training your network where you left off.
+              get_solvers().  That is, the existing state of the solvers and network is
+              the starting point for the optimization each time train() is called.  In
+              particular, if you use the set_synchronization_file() method you can
+              resume an interrupted train() call by simply calling train() again and it
+              will pick up from the last synchronization point.
             - You can obtain the average loss value during the final training epoch by
               calling get_average_loss().
         !*/
@@ -398,14 +378,6 @@
     };
 
-    template <typename T, typename U>
-    void serialize(const dnn_trainer<T,U>& item, std::ostream& out);
-    template <typename T, typename U>
-    void deserialize(dnn_trainer<T,U>& item, std::istream& in);
-    /*!
-        provides serialization support
-    !*/
-
 // ----------------------------------------------------------------------------------------
 
 }