Commit b974a575 authored by Davis King

Added set_learning_rate_schedule() to dnn_trainer.

parent 13cc545d
...@@ -194,9 +194,8 @@ namespace dlib ...@@ -194,9 +194,8 @@ namespace dlib
last_time = now_time; last_time = now_time;
std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " " std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " "
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " " << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " " << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
<< "steps without apparent progress: " << steps_without_progress print_progress();
<< std::endl;
clear_average_loss(); clear_average_loss();
} }
} }
...@@ -220,9 +219,8 @@ namespace dlib ...@@ -220,9 +219,8 @@ namespace dlib
last_time = now_time; last_time = now_time;
std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " " std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " "
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " " << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " " << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
<< "steps without apparent progress: " << steps_without_progress print_progress();
<< std::endl;
clear_average_loss(); clear_average_loss();
} }
} }
...@@ -260,9 +258,8 @@ namespace dlib ...@@ -260,9 +258,8 @@ namespace dlib
auto iter = epoch_iteration + epoch_pos/(double)data.size(); auto iter = epoch_iteration + epoch_pos/(double)data.size();
std::cout << "epoch: " << rpad(cast_to_string(iter),epoch_string_pad) << " " std::cout << "epoch: " << rpad(cast_to_string(iter),epoch_string_pad) << " "
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " " << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " " << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
<< "steps without apparent progress: " << steps_without_progress print_progress();
<< std::endl;
} }
} }
...@@ -280,9 +277,8 @@ namespace dlib ...@@ -280,9 +277,8 @@ namespace dlib
// are for full epoch status statements. // are for full epoch status statements.
std::cout << "Epoch: " << rpad(cast_to_string(epoch_iteration+1),epoch_string_pad) << " " std::cout << "Epoch: " << rpad(cast_to_string(epoch_iteration+1),epoch_string_pad) << " "
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " " << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " " << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
<< "steps without apparent progress: " << steps_without_progress print_progress();
<< std::endl;
} }
} }
wait_for_thread_to_pause(); wait_for_thread_to_pause();
...@@ -322,9 +318,8 @@ namespace dlib ...@@ -322,9 +318,8 @@ namespace dlib
auto iter = epoch_iteration + epoch_pos/(double)data.size(); auto iter = epoch_iteration + epoch_pos/(double)data.size();
std::cout << "epoch: " << rpad(cast_to_string(iter),epoch_string_pad) << " " std::cout << "epoch: " << rpad(cast_to_string(iter),epoch_string_pad) << " "
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " " << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " " << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
<< "steps without apparent progress: " << steps_without_progress print_progress();
<< std::endl;
} }
} }
...@@ -341,9 +336,8 @@ namespace dlib ...@@ -341,9 +336,8 @@ namespace dlib
// are for full epoch status statements. // are for full epoch status statements.
std::cout << "Epoch: " << rpad(cast_to_string(epoch_iteration+1),epoch_string_pad) << " " std::cout << "Epoch: " << rpad(cast_to_string(epoch_iteration+1),epoch_string_pad) << " "
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " " << "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " " << "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
<< "steps without apparent progress: " << steps_without_progress print_progress();
<< std::endl;
} }
} }
wait_for_thread_to_pause(); wait_for_thread_to_pause();
...@@ -389,6 +383,7 @@ namespace dlib ...@@ -389,6 +383,7 @@ namespace dlib
if (learning_rate != lr) if (learning_rate != lr)
previous_loss_values.clear(); previous_loss_values.clear();
learning_rate = lr; learning_rate = lr;
lr_schedule.set_size(0);
} }
double get_learning_rate( double get_learning_rate(
...@@ -402,6 +397,8 @@ namespace dlib ...@@ -402,6 +397,8 @@ namespace dlib
) )
{ {
DLIB_CASSERT(lr > 0,""); DLIB_CASSERT(lr > 0,"");
wait_for_thread_to_pause();
lr_schedule.set_size(0);
min_learning_rate = lr; min_learning_rate = lr;
} }
...@@ -411,10 +408,32 @@ namespace dlib ...@@ -411,10 +408,32 @@ namespace dlib
return min_learning_rate; return min_learning_rate;
} }
template <typename EXP>
void set_learning_rate_schedule (
const matrix_exp<EXP>& schedule
)
{
DLIB_CASSERT(schedule.size() > 0,"");
DLIB_CASSERT(min(schedule) > 0,"");
set_learning_rate(schedule(0,0));
set_min_learning_rate(min(schedule));
set_learning_rate_shrink_amount(1);
lr_schedule = matrix_cast<double>(reshape_to_column_vector(schedule));
lr_schedule_pos = 0;
}
const matrix<double,0,1>& get_learning_rate_schedule (
) const
{
return lr_schedule;
}
void set_iterations_without_progress_threshold ( void set_iterations_without_progress_threshold (
unsigned long thresh unsigned long thresh
) )
{ {
wait_for_thread_to_pause();
lr_schedule.set_size(0);
iter_without_progress_thresh = thresh; iter_without_progress_thresh = thresh;
} }
...@@ -429,6 +448,8 @@ namespace dlib ...@@ -429,6 +448,8 @@ namespace dlib
) )
{ {
DLIB_CASSERT(0 < shrink && shrink <= 1,""); DLIB_CASSERT(0 < shrink && shrink <= 1,"");
wait_for_thread_to_pause();
lr_schedule.set_size(0);
learning_rate_shrink = shrink; learning_rate_shrink = shrink;
} }
...@@ -608,6 +629,13 @@ namespace dlib ...@@ -608,6 +629,13 @@ namespace dlib
previous_loss_values.clear(); previous_loss_values.clear();
} }
} }
else if (lr_schedule.size() != 0) // or use the learning rate schedule if we have one.
{
if (lr_schedule_pos < lr_schedule.size())
learning_rate = lr_schedule(lr_schedule_pos++);
else
learning_rate = lr_schedule(lr_schedule.size()-1)*0.99;
}
} }
} }
catch(std::exception& e) catch(std::exception& e)
...@@ -639,6 +667,7 @@ namespace dlib ...@@ -639,6 +667,7 @@ namespace dlib
epoch_pos = 0; epoch_pos = 0;
train_one_step_calls = 0; train_one_step_calls = 0;
gradient_check_budget = 0; gradient_check_budget = 0;
lr_schedule_pos = 0;
start(); start();
} }
...@@ -649,7 +678,7 @@ namespace dlib ...@@ -649,7 +678,7 @@ namespace dlib
friend void serialize(const dnn_trainer& item, std::ostream& out) friend void serialize(const dnn_trainer& item, std::ostream& out)
{ {
item.wait_for_thread_to_pause(); item.wait_for_thread_to_pause();
int version = 6; int version = 7;
serialize(version, out); serialize(version, out);
size_t nl = dnn_trainer::num_layers; size_t nl = dnn_trainer::num_layers;
...@@ -669,13 +698,15 @@ namespace dlib ...@@ -669,13 +698,15 @@ namespace dlib
serialize(item.epoch_iteration, out); serialize(item.epoch_iteration, out);
serialize(item.epoch_pos, out); serialize(item.epoch_pos, out);
serialize(item.train_one_step_calls, out); serialize(item.train_one_step_calls, out);
serialize(item.lr_schedule, out);
serialize(item.lr_schedule_pos, out);
} }
friend void deserialize(dnn_trainer& item, std::istream& in) friend void deserialize(dnn_trainer& item, std::istream& in)
{ {
item.wait_for_thread_to_pause(); item.wait_for_thread_to_pause();
int version = 0; int version = 0;
deserialize(version, in); deserialize(version, in);
if (version != 6) if (version != 6 && version != 7)
throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer."); throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
size_t num_layers = 0; size_t num_layers = 0;
...@@ -705,6 +736,16 @@ namespace dlib ...@@ -705,6 +736,16 @@ namespace dlib
deserialize(item.epoch_iteration, in); deserialize(item.epoch_iteration, in);
deserialize(item.epoch_pos, in); deserialize(item.epoch_pos, in);
deserialize(item.train_one_step_calls, in); deserialize(item.train_one_step_calls, in);
if (version == 7)
{
deserialize(item.lr_schedule, in);
deserialize(item.lr_schedule_pos, in);
}
else
{
item.lr_schedule.set_size(0);
item.lr_schedule_pos = 0;
}
if (item.devices.size() > 1) if (item.devices.size() > 1)
{ {
...@@ -834,6 +875,21 @@ namespace dlib ...@@ -834,6 +875,21 @@ namespace dlib
send_job(dbegin, dend, nothing); send_job(dbegin, dend, nothing);
} }
void print_progress()
{
if (lr_schedule.size() == 0)
{
std::cout << "steps without apparent progress: " << steps_without_progress;
}
else
{
std::ostringstream sout;
sout << "percent complete: " << std::fixed << std::setprecision(2) << 100.0*lr_schedule_pos/(double)lr_schedule.size() << "%";
std::cout << sout.str();
}
std::cout << std::endl;
}
std::vector<std::shared_ptr<device_data>> devices; std::vector<std::shared_ptr<device_data>> devices;
dlib::pipe<job_t> job_pipe; dlib::pipe<job_t> job_pipe;
job_t job; job_t job;
...@@ -857,10 +913,11 @@ namespace dlib ...@@ -857,10 +913,11 @@ namespace dlib
size_t epoch_pos; size_t epoch_pos;
std::chrono::time_point<std::chrono::system_clock> last_time; std::chrono::time_point<std::chrono::system_clock> last_time;
unsigned long long train_one_step_calls; unsigned long long train_one_step_calls;
matrix<double,0,1> lr_schedule;
long lr_schedule_pos;
unsigned long gradient_check_budget; unsigned long gradient_check_budget;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -72,6 +72,7 @@ namespace dlib ...@@ -72,6 +72,7 @@ namespace dlib
- #get_min_learning_rate() == 1e-5 - #get_min_learning_rate() == 1e-5
- #get_iterations_without_progress_threshold() == 2000 - #get_iterations_without_progress_threshold() == 2000
- #get_learning_rate_shrink() == 0.1 - #get_learning_rate_shrink() == 0.1
- #get_learning_rate_schedule().size() == 0
- if (cuda_extra_devices.size() > 0) then - if (cuda_extra_devices.size() > 0) then
- This object will use multiple graphics cards to run the learning - This object will use multiple graphics cards to run the learning
algorithms. In particular, it will always use whatever device is algorithms. In particular, it will always use whatever device is
...@@ -152,6 +153,7 @@ namespace dlib ...@@ -152,6 +153,7 @@ namespace dlib
- lr > 0 - lr > 0
ensures ensures
- #get_learning_rate() == lr - #get_learning_rate() == lr
- #get_learning_rate_schedule().size() == 0
- This function blocks until all threads inside the dnn_trainer have - This function blocks until all threads inside the dnn_trainer have
stopped touching the net. stopped touching the net.
!*/ !*/
...@@ -164,7 +166,9 @@ namespace dlib ...@@ -164,7 +166,9 @@ namespace dlib
of each layer in the network. It does this by outputting a step vector of each layer in the network. It does this by outputting a step vector
that, when added to the parameters, will hopefully result in improved that, when added to the parameters, will hopefully result in improved
network performance. The learning rate is one of the inputs to the network performance. The learning rate is one of the inputs to the
solver and influences the size of this step vector. solver and influences the size of this step vector. This function
returns the current learning rate, that is, the learning rate that will
be used during the next training step.
!*/ !*/
void set_min_learning_rate ( void set_min_learning_rate (
...@@ -175,6 +179,9 @@ namespace dlib ...@@ -175,6 +179,9 @@ namespace dlib
- lr > 0 - lr > 0
ensures ensures
- #get_min_learning_rate() == lr - #get_min_learning_rate() == lr
- #get_learning_rate_schedule().size() == 0
- This function blocks until all threads inside the dnn_trainer have
stopped touching the net.
!*/ !*/
double get_min_learning_rate ( double get_min_learning_rate (
...@@ -191,12 +198,49 @@ namespace dlib ...@@ -191,12 +198,49 @@ namespace dlib
learning rate will drop infinitely close to 0 if you run long enough. learning rate will drop infinitely close to 0 if you run long enough.
!*/ !*/
template <typename EXP>
void set_learning_rate_schedule (
const matrix_exp<EXP>& schedule
);
/*!
requires
- schedule.size() > 0
- min(schedule) > 0
ensures
- #get_learning_rate_schedule() == reshape_to_column_vector(schedule)
- #get_learning_rate() == schedule(0,0)
- #get_min_learning_rate() == min(schedule)
- #set_learning_rate_shrink_amount() == 1
!*/
const matrix<double,0,1>& get_learning_rate_schedule (
) const;
/*!
ensures
- if (this function returns a non-empty matrix) then
- This trainer will use an explicit learning rate schedule defined by
the learning rate values in get_learning_rate_schedule(). For
example, if get_learning_rate_schedule() returned {0.1, 0.09, 0.08,
0.07, 0.6} then the first training mini-batch would use a learning
rate of 0.1, then the next training mini-batch uses 0.09, and then
0.8, and so on until the end of the schedule is reached.
If you continue to run training after the end of the schedule has
been reached then the learning rate will be fixed to 0.99 times the
final value. So in our example, eventually the learning rate would
be fixed to 0.99*0.6. This allows you to test if we have reached the
end of the schedule by checking if get_learning_rate() >= 0.6.
!*/
void set_iterations_without_progress_threshold ( void set_iterations_without_progress_threshold (
unsigned long thresh unsigned long thresh
); );
/*! /*!
ensures ensures
- #get_iterations_without_progress_threshold() == thresh - #get_iterations_without_progress_threshold() == thresh
- #get_learning_rate_schedule().size() == 0
- This function blocks until all threads inside the dnn_trainer have
stopped touching the net.
!*/ !*/
unsigned long get_iterations_without_progress_threshold ( unsigned long get_iterations_without_progress_threshold (
...@@ -225,6 +269,9 @@ namespace dlib ...@@ -225,6 +269,9 @@ namespace dlib
- 0 < shrink && shrink <= 1 - 0 < shrink && shrink <= 1
ensures ensures
- #get_learning_rate_shrink() == shrink - #get_learning_rate_shrink() == shrink
- #get_learning_rate_schedule().size() == 0
- This function blocks until all threads inside the dnn_trainer have
stopped touching the net.
!*/ !*/
double get_learning_rate_shrink ( double get_learning_rate_shrink (
......
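
For context, here is a minimal usage sketch of the new schedule API. The toy network type, the linspace()-based schedule, and the specific numbers are illustrative assumptions by the editor, not part of this commit; linspace() and the layer templates are assumed to be the standard dlib facilities.

// Illustrative sketch only: the network and numbers below are made up for
// demonstration and are not taken from this commit.
#include <dlib/dnn.h>

using namespace dlib;

// A tiny toy network; any dlib network type would work the same way.
using net_type = loss_multiclass_log<fc<10, input<matrix<unsigned char>>>>;

int main()
{
    net_type net;
    dnn_trainer<net_type> trainer(net);

    // Instead of relying on the automatic learning rate shrinking, hand the
    // trainer an explicit schedule: one learning rate per mini-batch, here
    // decaying linearly from 0.1 to 0.001 over 10000 steps.
    trainer.set_learning_rate_schedule(linspace(0.1, 0.001, 10000));

    // Load data and call trainer.train_one_step(samples, labels) in a loop
    // (or trainer.train(samples, labels)) as usual.  Once the schedule is
    // exhausted the learning rate stays pinned at 0.99 times its last value,
    // so get_learning_rate() dropping below the schedule's minimum indicates
    // the schedule has finished.
}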