Commit ed683785 authored by Davis King's avatar Davis King

Added get_synchronization_file() and get_test_one_step_calls() to dnn_trainer.

Also added an operator<< for dnn_trainer that prints the parameters it's using.

These changes also break backwards compatibility with the previous
serialization format for dnn_trainer objects.
parent 9540ca23
......@@ -23,6 +23,7 @@
#include <exception>
#include <mutex>
#include "../dir_nav.h"
#include "../md5.h"
namespace dlib
{
......@@ -263,7 +264,7 @@ namespace dlib
sync_to_disk();
send_job(true, dbegin, dend, lbegin);
++train_one_step_calls;
++test_one_step_calls;
}
void test_one_step (
......@@ -285,7 +286,7 @@ namespace dlib
print_periodic_verbose_status();
sync_to_disk();
send_job(true, dbegin, dend);
++train_one_step_calls;
++test_one_step_calls;
}
void train (
......@@ -415,6 +416,12 @@ namespace dlib
deserialize(*this, fin);
}
const std::string& get_synchronization_file (
)
{
return sync_filename;
}
double get_average_loss (
) const
{
......@@ -564,6 +571,12 @@ namespace dlib
return train_one_step_calls;
}
// Returns the number of times test_one_step() has been called on this object.
// The counter is persisted by serialize()/deserialize() and reset to 0 when
// the trainer is (re)initialized.
unsigned long long get_test_one_step_calls (
) const
{
return test_one_step_calls;
}
private:
void record_test_loss(double loss)
......@@ -849,6 +862,7 @@ namespace dlib
epoch_iteration = 0;
epoch_pos = 0;
train_one_step_calls = 0;
test_one_step_calls = 0;
gradient_check_budget = 0;
lr_schedule_pos = 0;
......@@ -869,7 +883,7 @@ namespace dlib
friend void serialize(const dnn_trainer& item, std::ostream& out)
{
item.wait_for_thread_to_pause();
int version = 8;
int version = 9;
serialize(version, out);
size_t nl = dnn_trainer::num_layers;
......@@ -889,6 +903,7 @@ namespace dlib
serialize(item.epoch_iteration, out);
serialize(item.epoch_pos, out);
serialize(item.train_one_step_calls, out);
serialize(item.test_one_step_calls, out);
serialize(item.lr_schedule, out);
serialize(item.lr_schedule_pos, out);
serialize(item.test_iter_without_progress_thresh.load(), out);
......@@ -901,7 +916,7 @@ namespace dlib
item.wait_for_thread_to_pause();
int version = 0;
deserialize(version, in);
if (version != 8)
if (version != 9)
throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
size_t num_layers = 0;
......@@ -931,6 +946,7 @@ namespace dlib
deserialize(item.epoch_iteration, in);
deserialize(item.epoch_pos, in);
deserialize(item.train_one_step_calls, in);
deserialize(item.test_one_step_calls, in);
deserialize(item.lr_schedule, in);
deserialize(item.lr_schedule_pos, in);
deserialize(ltemp, in); item.test_iter_without_progress_thresh = ltemp;
......@@ -1220,6 +1236,7 @@ namespace dlib
size_t epoch_pos;
std::chrono::time_point<std::chrono::system_clock> last_time;
unsigned long long train_one_step_calls;
unsigned long long test_one_step_calls;
matrix<double,0,1> lr_schedule;
long lr_schedule_pos;
unsigned long gradient_check_budget;
......@@ -1244,6 +1261,41 @@ namespace dlib
bool sync_file_reloaded;
};
// ----------------------------------------------------------------------------------------
template <
typename net_type,
typename solver_type
>
std::ostream& operator<< (
std::ostream& out,
dnn_trainer<net_type,solver_type>& trainer
)
{
    // Print a human-readable summary of the trainer's configuration to out.
    out << "dnn_trainer details: \n";
    out << " net_type::num_layers: " << net_type::num_layers << std::endl;
    out << " net architecture hash: " << md5(cast_to_string(trainer.get_net())) << std::endl;
    out << " loss: " << trainer.get_net().loss_details() << std::endl;
    out << " synchronization file: " << trainer.get_synchronization_file() << std::endl;
    out << " trainer.get_solvers()[0]: " << trainer.get_solvers()[0] << std::endl;

    if (trainer.get_learning_rate_schedule().size() == 0)
    {
        // No user-supplied schedule, so report the adaptive learning rate settings.
        out << " learning rate: "<< trainer.get_learning_rate() << std::endl;
        out << " learning rate shrink factor: "<< trainer.get_learning_rate_shrink_factor() << std::endl;
        out << " min learning rate: "<< trainer.get_min_learning_rate() << std::endl;
        out << " iterations without progress threshold: "<< trainer.get_iterations_without_progress_threshold() << std::endl;
        out << " test iterations without progress threshold: "<< trainer.get_test_iterations_without_progress_threshold() << std::endl;
    }
    else
    {
        out << " using explicit user-supplied learning rate schedule" << std::endl;
    }
    return out;
}
// ----------------------------------------------------------------------------------------
}
......
......@@ -80,6 +80,8 @@ namespace dlib
- #get_learning_rate_shrink_factor() == 0.1
- #get_learning_rate_schedule().size() == 0
- #get_train_one_step_calls() == 0
- #get_test_one_step_calls() == 0
- #get_synchronization_file() == ""
- if (cuda_extra_devices.size() > 0) then
- This object will use multiple graphics cards to run the learning
algorithms. In particular, it will always use whatever device is
......@@ -311,6 +313,13 @@ namespace dlib
- returns the number of times train_one_step() has been called.
!*/
unsigned long long get_test_one_step_calls (
) const;
/*!
ensures
- returns the number of times test_one_step() has been called.
!*/
void be_verbose (
);
/*!
......@@ -332,6 +341,7 @@ namespace dlib
);
/*!
ensures
- #get_synchronization_file() == filename
- While training is running, either via train() or repeated calls to
train_one_step(), this object will save its entire state, including the
state of get_net(), to disk in the file named filename every
......@@ -349,6 +359,14 @@ namespace dlib
load from the newest of the two possible files.
!*/
const std::string& get_synchronization_file (
);
/*!
ensures
- Returns the name of the file the dnn_trainer will periodically save its
state to.  If the return value is "" then synchronization is disabled.
!*/
void train (
const std::vector<input_type>& data,
const std::vector<training_label_type>& labels
......@@ -580,6 +598,7 @@ namespace dlib
this function you should call get_net() before you touch the net object
from the calling thread to ensure no other threads are still accessing
the network.
- #get_test_one_step_calls() == get_test_one_step_calls() + 1.
!*/
template <
......@@ -611,6 +630,7 @@ namespace dlib
this function you should call get_net() before you touch the net object
from the calling thread to ensure no other threads are still accessing
the network.
- #get_test_one_step_calls() == get_test_one_step_calls() + 1.
!*/
void test_one_step (
......@@ -635,6 +655,7 @@ namespace dlib
this function you should call get_net() before you touch the net object
from the calling thread to ensure no other threads are still accessing
the network.
- #get_test_one_step_calls() == get_test_one_step_calls() + 1.
!*/
template <
......@@ -663,6 +684,7 @@ namespace dlib
this function you should call get_net() before you touch the net object
from the calling thread to ensure no other threads are still accessing
the network.
- #get_test_one_step_calls() == get_test_one_step_calls() + 1.
!*/
void set_test_iterations_without_progress_threshold (
......@@ -710,6 +732,21 @@ namespace dlib
};
// ----------------------------------------------------------------------------------------
template <
typename net_type,
typename solver_type
>
std::ostream& operator<< (
std::ostream& out,
dnn_trainer<net_type,solver_type>& trainer
);
/*!
ensures
- Prints a log of the current parameters of trainer to out.
- returns out
!*/
// ----------------------------------------------------------------------------------------
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment