Commit 68412221 authored by Davis King

Improved the dnn_trainer. In particular, it no longer makes a copy of the
network (which would needlessly double VRAM usage).  I also added a
set_synchronization_file() method so you can tell it to automatically
synchronize itself to disk every so often during training.  This makes resuming
an interrupted training session trivially easy.
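In practice the new usage looks something like the sketch below. It is
illustrative only: net_type stands for whatever dlib network type you have
defined, the default sgd solver is assumed, and the synchronization file name
is made up.

    // Minimal sketch of the new trainer usage.  net must outlive the trainer
    // because dnn_trainer now holds a reference to it instead of a copy.
    #include <dlib/dnn.h>
    #include <chrono>
    #include <vector>

    template <typename net_type>
    void run_training(
        net_type& net,
        const std::vector<typename net_type::input_type>& data,
        const std::vector<typename net_type::label_type>& labels
    )
    {
        dlib::dnn_trainer<net_type> trainer(net);

        // Save the trainer's entire state (including the net) to disk every 5
        // minutes.  If "train_state.dat" (a made-up name) already exists it is
        // loaded here, so rerunning the program resumes an interrupted session.
        trainer.set_synchronization_file("train_state.dat", std::chrono::minutes(5));

        trainer.train(data, labels);
        // The trained network is net itself; trainer.get_net() returns a
        // reference to it rather than an internal copy.
    }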
parent 4189386d
@@ -6,6 +6,7 @@
 #include "core_abstract.h"
 #include "solvers_abstract.h"
 #include <vector>
+#include <chrono>
 
 namespace dlib
@@ -39,43 +40,21 @@ namespace dlib
         typedef typename net_type::label_type label_type;
         typedef typename net_type::input_type input_type;
 
         const static size_t num_layers = net_type::num_layers;
 
-        dnn_trainer(
-        );
-        /*!
-            ensures
-                - #get_net() == a default initialized net_type object.
-                - #get_solvers() == a set of default initialized solvers.
-                - #get_max_num_epochs() == 10000
-                - #get_mini_batch_size() == 128
-                - #get_step_size() == 1
-                - #get_min_step_size() == 1e-4
-                - #get_iterations_between_step_size_adjust() == 2000
-                - #get_step_size_shrink() == 0.1
-        !*/
-
-        explicit dnn_trainer(
-            const net_type& net
-        );
-        /*!
-            ensures
-                - #get_net() == net
-                - #get_solvers() == a set of default initialized solvers.
-                - #get_max_num_epochs() == 10000
-                - #get_mini_batch_size() == 128
-                - #get_step_size() == 1
-                - #get_min_step_size() == 1e-4
-                - #get_iterations_between_step_size_adjust() == 2000
-                - #get_step_size_shrink() == 0.1
-        !*/
+        dnn_trainer() = delete;
+        dnn_trainer(const dnn_trainer&) = delete;
 
         dnn_trainer(
-            const net_type& net,
-            const solver_type& solver
+            net_type& net,
+            const solver_type& solver = solver_type()
         );
         /*!
             ensures
-                - #get_net() == net
+                - &#get_net() == &net
+                  (i.e. The dnn_trainer holds a reference to net, it does not copy it.
+                  Therefore, you must ensure net has a lifetime at least as long as the
+                  dnn_trainer).
                 - #get_solvers() == a set of solvers that are all initialized with the
                   provided solver instance.
                 - #get_max_num_epochs() == 10000
@@ -86,20 +65,15 @@
                 - #get_step_size_shrink() == 0.1
         !*/
 
-        const net_type& get_net (
+        net_type& get_net (
         ) const;
         /*!
            ensures
-                - returns the neural network object in this trainer. This is the network
-                  that is optimized when you call train().
-        !*/
-
-        void set_net (
-            const net_type& net
-        );
-        /*!
-            ensures
-                - #get_net() == net
+                - returns the neural network object used by this trainer. This is the
+                  network that is optimized when you call train() or train_one_step().
+                  Recall that the dnn_trainer doesn't contain the net_type object but
+                  simply holds a reference to an external network which was provided to the
+                  dnn_trainer's constructor.
         !*/
 
         void set_solver (
@@ -275,7 +249,23 @@ namespace dlib
                 - This object will not print anything to standard out
         !*/
 
-        const net_type& train (
+        void set_synchronization_file (
+            const std::string& filename,
+            std::chrono::seconds time_between_syncs = std::chrono::minutes(15)
+        );
+        /*!
+            ensures
+                - While training is running, either via train() or repeated calls to
+                  train_one_step(), this object will save its entire state, including the
+                  state of get_net(), to disk in the file named filename every
+                  time_between_syncs seconds.
+                - if the filename file already exists then the state of this trainer will
+                  be loaded from that file by this call to set_synchronization_file().
+                  This allows you to resume a training session which was previously
+                  interrupted.
+        !*/
+
+        void train (
             const std::vector<input_type>& data,
             const std::vector<label_type>& labels
         );
@@ -292,22 +282,17 @@
                   get_max_num_epochs() training epochs have been executed.
                 - Each layer in the network will be optimized by its corresponding solver
                   in get_solvers().
-                - returns #get_net()
-                  (i.e. the trained network can also be accessed by calling get_net() after
-                  train() finishes executing)
                 - Each call to train DOES NOT reinitialize the state of get_net() or
-                  get_solvers(). That is, the state of the solvers and network contained
-                  inside this trainer is the starting point for the optimization each time
-                  train() is called. For example, calling train() 1 time and having it
-                  execute 100 epochs of training is equivalent to calling train() 10 times
-                  and having it execute 10 epochs of training during each call. This also
-                  means you can serialize a trainer to disk and then, at a later date,
-                  deserialize it and resume training your network where you left off.
+                  get_solvers(). That is, the existing state of the solvers and network is
+                  the starting point for the optimization each time train() is called. In
+                  particular, if you use the set_synchronization_file() method you can
+                  resume an interrupted train() call by simply calling train() again and it
+                  will pick up from the last synchronization point.
                 - You can obtain the average loss value during the final training epoch by
                   calling get_average_loss().
         !*/
 
-        const net_type& train (
+        void train (
             const std::vector<input_type>& data
         );
         /*!
@@ -322,17 +307,12 @@
                   get_max_num_epochs() training epochs have been executed.
                 - Each layer in the network will be optimized by its corresponding solver
                   in get_solvers().
-                - returns #get_net()
-                  (i.e. the trained network can also be accessed by calling get_net() after
-                  train() finishes executing)
                 - Each call to train DOES NOT reinitialize the state of get_net() or
-                  get_solvers(). That is, the state of the solvers and network contained
-                  inside this trainer is the starting point for the optimization each time
-                  train() is called. For example, calling train() 1 time and having it
-                  execute 100 epochs of training is equivalent to calling train() 10 times
-                  and having it execute 10 epochs of training during each call. This also
-                  means you can serialize a trainer to disk and then, at a later date,
-                  deserialize it and resume training your network where you left off.
+                  get_solvers(). That is, the existing state of the solvers and network is
+                  the starting point for the optimization each time train() is called. In
+                  particular, if you use the set_synchronization_file() method you can
+                  resume an interrupted train() call by simply calling train() again and it
+                  will pick up from the last synchronization point.
                 - You can obtain the average loss value during the final training epoch by
                   calling get_average_loss().
         !*/
@@ -398,14 +378,6 @@
     };
 
-    template <typename T, typename U>
-    void serialize(const dnn_trainer<T,U>& item, std::ostream& out);
-    template <typename T, typename U>
-    void deserialize(dnn_trainer<T,U>& item, std::istream& in);
-    /*!
-        provides serialization support
-    !*/
-
 // ----------------------------------------------------------------------------------------
 
 }
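The synchronization file also composes with hand-written mini-batch loops. A
hedged sketch follows: the train_one_step() overload taking a mini-batch of
data and labels is assumed from its mention in the documentation above, and
the batching scheme and file name are purely illustrative.

    // Sketch of a custom training loop.  Periodic synchronization to
    // "loop_state.dat" (a made-up name) happens automatically while
    // train_one_step() is being called, using the default 15 minute interval.
    #include <dlib/dnn.h>
    #include <vector>

    template <typename net_type>
    void run_custom_loop(
        net_type& net,
        const std::vector<std::vector<typename net_type::input_type>>& data_batches,
        const std::vector<std::vector<typename net_type::label_type>>& label_batches
    )
    {
        dlib::dnn_trainer<net_type> trainer(net);
        trainer.set_synchronization_file("loop_state.dat");

        // Each call performs one solver update on one mini-batch.  If the
        // program dies mid-loop, rerunning it reloads the trainer's state from
        // loop_state.dat and training continues from the last sync point.
        for (size_t i = 0; i < data_batches.size(); ++i)
            trainer.train_one_step(data_batches[i], label_batches[i]);
    }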