Commit ed683785 authored by Davis King

Added get_synchronization_file() and get_test_one_step_calls() to dnn_trainer.

Also added an operator<< for dnn_trainer that prints the parameters it's using.

These changes also break backwards compatibility with the previous
serialization format for dnn_trainer objects.
parent 9540ca23
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include <exception> #include <exception>
#include <mutex> #include <mutex>
#include "../dir_nav.h" #include "../dir_nav.h"
#include "../md5.h"
namespace dlib namespace dlib
{ {
...@@ -263,7 +264,7 @@ namespace dlib ...@@ -263,7 +264,7 @@ namespace dlib
sync_to_disk(); sync_to_disk();
send_job(true, dbegin, dend, lbegin); send_job(true, dbegin, dend, lbegin);
++train_one_step_calls; ++test_one_step_calls;
} }
void test_one_step ( void test_one_step (
...@@ -285,7 +286,7 @@ namespace dlib ...@@ -285,7 +286,7 @@ namespace dlib
print_periodic_verbose_status(); print_periodic_verbose_status();
sync_to_disk(); sync_to_disk();
send_job(true, dbegin, dend); send_job(true, dbegin, dend);
++train_one_step_calls; ++test_one_step_calls;
} }
void train ( void train (
...@@ -415,6 +416,12 @@ namespace dlib ...@@ -415,6 +416,12 @@ namespace dlib
deserialize(*this, fin); deserialize(*this, fin);
} }
// Returns the filename given to set_synchronization_file(), i.e. the file
// this trainer periodically saves its state to.  An empty string means
// synchronization to disk is disabled.
const std::string& get_synchronization_file (
)
{
return sync_filename;
}
double get_average_loss ( double get_average_loss (
) const ) const
{ {
...@@ -564,6 +571,12 @@ namespace dlib ...@@ -564,6 +571,12 @@ namespace dlib
return train_one_step_calls; return train_one_step_calls;
} }
// Returns the number of times test_one_step() has been called on this
// trainer (the counter is persisted by serialize()/deserialize()).
unsigned long long get_test_one_step_calls (
) const
{
return test_one_step_calls;
}
private: private:
void record_test_loss(double loss) void record_test_loss(double loss)
...@@ -849,6 +862,7 @@ namespace dlib ...@@ -849,6 +862,7 @@ namespace dlib
epoch_iteration = 0; epoch_iteration = 0;
epoch_pos = 0; epoch_pos = 0;
train_one_step_calls = 0; train_one_step_calls = 0;
test_one_step_calls = 0;
gradient_check_budget = 0; gradient_check_budget = 0;
lr_schedule_pos = 0; lr_schedule_pos = 0;
...@@ -869,7 +883,7 @@ namespace dlib ...@@ -869,7 +883,7 @@ namespace dlib
friend void serialize(const dnn_trainer& item, std::ostream& out) friend void serialize(const dnn_trainer& item, std::ostream& out)
{ {
item.wait_for_thread_to_pause(); item.wait_for_thread_to_pause();
int version = 8; int version = 9;
serialize(version, out); serialize(version, out);
size_t nl = dnn_trainer::num_layers; size_t nl = dnn_trainer::num_layers;
...@@ -889,6 +903,7 @@ namespace dlib ...@@ -889,6 +903,7 @@ namespace dlib
serialize(item.epoch_iteration, out); serialize(item.epoch_iteration, out);
serialize(item.epoch_pos, out); serialize(item.epoch_pos, out);
serialize(item.train_one_step_calls, out); serialize(item.train_one_step_calls, out);
serialize(item.test_one_step_calls, out);
serialize(item.lr_schedule, out); serialize(item.lr_schedule, out);
serialize(item.lr_schedule_pos, out); serialize(item.lr_schedule_pos, out);
serialize(item.test_iter_without_progress_thresh.load(), out); serialize(item.test_iter_without_progress_thresh.load(), out);
...@@ -901,7 +916,7 @@ namespace dlib ...@@ -901,7 +916,7 @@ namespace dlib
item.wait_for_thread_to_pause(); item.wait_for_thread_to_pause();
int version = 0; int version = 0;
deserialize(version, in); deserialize(version, in);
if (version != 8) if (version != 9)
throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer."); throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
size_t num_layers = 0; size_t num_layers = 0;
...@@ -931,6 +946,7 @@ namespace dlib ...@@ -931,6 +946,7 @@ namespace dlib
deserialize(item.epoch_iteration, in); deserialize(item.epoch_iteration, in);
deserialize(item.epoch_pos, in); deserialize(item.epoch_pos, in);
deserialize(item.train_one_step_calls, in); deserialize(item.train_one_step_calls, in);
deserialize(item.test_one_step_calls, in);
deserialize(item.lr_schedule, in); deserialize(item.lr_schedule, in);
deserialize(item.lr_schedule_pos, in); deserialize(item.lr_schedule_pos, in);
deserialize(ltemp, in); item.test_iter_without_progress_thresh = ltemp; deserialize(ltemp, in); item.test_iter_without_progress_thresh = ltemp;
...@@ -1220,6 +1236,7 @@ namespace dlib ...@@ -1220,6 +1236,7 @@ namespace dlib
size_t epoch_pos; size_t epoch_pos;
std::chrono::time_point<std::chrono::system_clock> last_time; std::chrono::time_point<std::chrono::system_clock> last_time;
unsigned long long train_one_step_calls; unsigned long long train_one_step_calls;
unsigned long long test_one_step_calls;
matrix<double,0,1> lr_schedule; matrix<double,0,1> lr_schedule;
long lr_schedule_pos; long lr_schedule_pos;
unsigned long gradient_check_budget; unsigned long gradient_check_budget;
...@@ -1244,6 +1261,41 @@ namespace dlib ...@@ -1244,6 +1261,41 @@ namespace dlib
bool sync_file_reloaded; bool sync_file_reloaded;
}; };
// ----------------------------------------------------------------------------------------
template <
    typename net_type,
    typename solver_type
    >
std::ostream& operator<< (
    std::ostream& out,
    dnn_trainer<net_type,solver_type>& trainer
)
{
    // Print a human-readable summary of the trainer's configuration, one
    // field per line.  The trainer is taken by non-const reference because
    // accessors such as get_net() may synchronize with the training thread.
    out << "dnn_trainer details: \n";
    out << " net_type::num_layers: " << net_type::num_layers << std::endl;
    // Hash of the serialized network so two logs can be compared for
    // architecture equality at a glance.
    out << " net architecture hash: " << md5(cast_to_string(trainer.get_net())) << std::endl;
    out << " loss: " << trainer.get_net().loss_details() << std::endl;
    out << " synchronization file: " << trainer.get_synchronization_file() << std::endl;
    out << " trainer.get_solvers()[0]: " << trainer.get_solvers()[0] << std::endl;

    const auto schedule = trainer.get_learning_rate_schedule();
    if (schedule.size() == 0)
    {
        // No user-supplied schedule: report the automatic learning rate
        // control parameters instead.
        out << " learning rate: " << trainer.get_learning_rate() << std::endl;
        out << " learning rate shrink factor: " << trainer.get_learning_rate_shrink_factor() << std::endl;
        out << " min learning rate: " << trainer.get_min_learning_rate() << std::endl;
        out << " iterations without progress threshold: " << trainer.get_iterations_without_progress_threshold() << std::endl;
        out << " test iterations without progress threshold: " << trainer.get_test_iterations_without_progress_threshold() << std::endl;
    }
    else
    {
        out << " using explicit user-supplied learning rate schedule" << std::endl;
    }
    return out;
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
...@@ -80,6 +80,8 @@ namespace dlib ...@@ -80,6 +80,8 @@ namespace dlib
- #get_learning_rate_shrink_factor() == 0.1 - #get_learning_rate_shrink_factor() == 0.1
- #get_learning_rate_schedule().size() == 0 - #get_learning_rate_schedule().size() == 0
- #get_train_one_step_calls() == 0 - #get_train_one_step_calls() == 0
- #get_test_one_step_calls() == 0
- #get_synchronization_file() == ""
- if (cuda_extra_devices.size() > 0) then - if (cuda_extra_devices.size() > 0) then
- This object will use multiple graphics cards to run the learning - This object will use multiple graphics cards to run the learning
algorithms. In particular, it will always use whatever device is algorithms. In particular, it will always use whatever device is
...@@ -311,6 +313,13 @@ namespace dlib ...@@ -311,6 +313,13 @@ namespace dlib
- returns the number of times train_one_step() has been called. - returns the number of times train_one_step() has been called.
!*/ !*/
unsigned long long get_test_one_step_calls (
) const;
/*!
ensures
- returns the number of times test_one_step() has been called.
!*/
void be_verbose ( void be_verbose (
); );
/*! /*!
...@@ -332,6 +341,7 @@ namespace dlib ...@@ -332,6 +341,7 @@ namespace dlib
); );
/*! /*!
ensures ensures
- #get_synchronization_file() == filename
- While training is running, either via train() or repeated calls to - While training is running, either via train() or repeated calls to
train_one_step(), this object will save its entire state, including the train_one_step(), this object will save its entire state, including the
state of get_net(), to disk in the file named filename every state of get_net(), to disk in the file named filename every
...@@ -349,6 +359,14 @@ namespace dlib ...@@ -349,6 +359,14 @@ namespace dlib
load from the newest of the two possible files. load from the newest of the two possible files.
!*/ !*/
const std::string& get_synchronization_file (
);
/*!
ensures
- Returns the name of the file the dnn_trainer will periodically save its
state to. If the return value is "" then synchronization is disabled.
!*/
void train ( void train (
const std::vector<input_type>& data, const std::vector<input_type>& data,
const std::vector<training_label_type>& labels const std::vector<training_label_type>& labels
...@@ -580,6 +598,7 @@ namespace dlib ...@@ -580,6 +598,7 @@ namespace dlib
this function you should call get_net() before you touch the net object this function you should call get_net() before you touch the net object
from the calling thread to ensure no other threads are still accessing from the calling thread to ensure no other threads are still accessing
the network. the network.
- #get_test_one_step_calls() == get_test_one_step_calls() + 1.
!*/ !*/
template < template <
...@@ -611,6 +630,7 @@ namespace dlib ...@@ -611,6 +630,7 @@ namespace dlib
this function you should call get_net() before you touch the net object this function you should call get_net() before you touch the net object
from the calling thread to ensure no other threads are still accessing from the calling thread to ensure no other threads are still accessing
the network. the network.
- #get_test_one_step_calls() == get_test_one_step_calls() + 1.
!*/ !*/
void test_one_step ( void test_one_step (
...@@ -635,6 +655,7 @@ namespace dlib ...@@ -635,6 +655,7 @@ namespace dlib
this function you should call get_net() before you touch the net object this function you should call get_net() before you touch the net object
from the calling thread to ensure no other threads are still accessing from the calling thread to ensure no other threads are still accessing
the network. the network.
- #get_test_one_step_calls() == get_test_one_step_calls() + 1.
!*/ !*/
template < template <
...@@ -663,6 +684,7 @@ namespace dlib ...@@ -663,6 +684,7 @@ namespace dlib
this function you should call get_net() before you touch the net object this function you should call get_net() before you touch the net object
from the calling thread to ensure no other threads are still accessing from the calling thread to ensure no other threads are still accessing
the network. the network.
- #get_test_one_step_calls() == get_test_one_step_calls() + 1.
!*/ !*/
void set_test_iterations_without_progress_threshold ( void set_test_iterations_without_progress_threshold (
...@@ -710,6 +732,21 @@ namespace dlib ...@@ -710,6 +732,21 @@ namespace dlib
}; };
// ----------------------------------------------------------------------------------------
template <
typename net_type,
typename solver_type
>
std::ostream& operator<< (
std::ostream& out,
dnn_trainer<net_type,solver_type>& trainer
);
/*!
ensures
- Prints a log of the current parameters of trainer to out.
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment