Commit b974a575 authored by Davis King

Added set_learning_rate_schedule() to dnn_trainer.

parent 13cc545d
@@ -193,10 +193,9 @@ namespace dlib
{
last_time = now_time;
std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " "
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " "
<< "steps without apparent progress: " << steps_without_progress
<< std::endl;
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
print_progress();
clear_average_loss();
}
}
@@ -219,10 +218,9 @@ namespace dlib
{
last_time = now_time;
std::cout << "step#: " << rpad(cast_to_string(train_one_step_calls),epoch_string_pad) << " "
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " "
<< "steps without apparent progress: " << steps_without_progress
<< std::endl;
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
print_progress();
clear_average_loss();
}
}
@@ -259,10 +257,9 @@ namespace dlib
last_time = now_time;
auto iter = epoch_iteration + epoch_pos/(double)data.size();
std::cout << "epoch: " << rpad(cast_to_string(iter),epoch_string_pad) << " "
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " "
<< "steps without apparent progress: " << steps_without_progress
<< std::endl;
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
print_progress();
}
}
@@ -280,9 +277,8 @@ namespace dlib
// are for full epoch status statements.
std::cout << "Epoch: " << rpad(cast_to_string(epoch_iteration+1),epoch_string_pad) << " "
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " "
<< "steps without apparent progress: " << steps_without_progress
<< std::endl;
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
print_progress();
}
}
wait_for_thread_to_pause();
@@ -321,10 +317,9 @@ namespace dlib
last_time = now_time;
auto iter = epoch_iteration + epoch_pos/(double)data.size();
std::cout << "epoch: " << rpad(cast_to_string(iter),epoch_string_pad) << " "
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " "
<< "steps without apparent progress: " << steps_without_progress
<< std::endl;
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
print_progress();
}
}
@@ -341,9 +336,8 @@ namespace dlib
// are for full epoch status statements.
std::cout << "Epoch: " << rpad(cast_to_string(epoch_iteration+1),epoch_string_pad) << " "
<< "learning rate: " << rpad(cast_to_string(learning_rate),lr_string_pad) << " "
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " "
<< "steps without apparent progress: " << steps_without_progress
<< std::endl;
<< "average loss: " << rpad(cast_to_string(get_average_loss()),string_pad) << " ";
print_progress();
}
}
wait_for_thread_to_pause();
@@ -389,6 +383,7 @@ namespace dlib
if (learning_rate != lr)
previous_loss_values.clear();
learning_rate = lr;
lr_schedule.set_size(0);
}
double get_learning_rate(
@@ -402,6 +397,8 @@ namespace dlib
)
{
DLIB_CASSERT(lr > 0,"");
wait_for_thread_to_pause();
lr_schedule.set_size(0);
min_learning_rate = lr;
}
@@ -411,10 +408,32 @@ namespace dlib
return min_learning_rate;
}
template <typename EXP>
void set_learning_rate_schedule (
const matrix_exp<EXP>& schedule
)
{
DLIB_CASSERT(schedule.size() > 0,"");
DLIB_CASSERT(min(schedule) > 0,"");
set_learning_rate(schedule(0,0));
set_min_learning_rate(min(schedule));
set_learning_rate_shrink_amount(1);
lr_schedule = matrix_cast<double>(reshape_to_column_vector(schedule));
lr_schedule_pos = 0;
}
const matrix<double,0,1>& get_learning_rate_schedule (
) const
{
return lr_schedule;
}
void set_iterations_without_progress_threshold (
unsigned long thresh
)
{
wait_for_thread_to_pause();
lr_schedule.set_size(0);
iter_without_progress_thresh = thresh;
}
@@ -429,6 +448,8 @@ namespace dlib
)
{
DLIB_CASSERT(0 < shrink && shrink <= 1,"");
wait_for_thread_to_pause();
lr_schedule.set_size(0);
learning_rate_shrink = shrink;
}
@@ -608,6 +629,13 @@ namespace dlib
previous_loss_values.clear();
}
}
else if (lr_schedule.size() != 0) // or use the learning rate schedule if we have one.
{
if (lr_schedule_pos < lr_schedule.size())
learning_rate = lr_schedule(lr_schedule_pos++);
else
learning_rate = lr_schedule(lr_schedule.size()-1)*0.99;
}
}
}
catch(std::exception& e)
@@ -639,6 +667,7 @@ namespace dlib
epoch_pos = 0;
train_one_step_calls = 0;
gradient_check_budget = 0;
lr_schedule_pos = 0;
start();
}
@@ -649,7 +678,7 @@ namespace dlib
friend void serialize(const dnn_trainer& item, std::ostream& out)
{
item.wait_for_thread_to_pause();
int version = 6;
int version = 7;
serialize(version, out);
size_t nl = dnn_trainer::num_layers;
@@ -669,13 +698,15 @@ namespace dlib
serialize(item.epoch_iteration, out);
serialize(item.epoch_pos, out);
serialize(item.train_one_step_calls, out);
serialize(item.lr_schedule, out);
serialize(item.lr_schedule_pos, out);
}
friend void deserialize(dnn_trainer& item, std::istream& in)
{
item.wait_for_thread_to_pause();
int version = 0;
deserialize(version, in);
if (version != 6)
if (version != 6 && version != 7)
throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
size_t num_layers = 0;
@@ -705,6 +736,16 @@ namespace dlib
deserialize(item.epoch_iteration, in);
deserialize(item.epoch_pos, in);
deserialize(item.train_one_step_calls, in);
if (version == 7)
{
deserialize(item.lr_schedule, in);
deserialize(item.lr_schedule_pos, in);
}
else
{
item.lr_schedule.set_size(0);
item.lr_schedule_pos = 0;
}
if (item.devices.size() > 1)
{
@@ -834,6 +875,21 @@ namespace dlib
send_job(dbegin, dend, nothing);
}
void print_progress()
{
if (lr_schedule.size() == 0)
{
std::cout << "steps without apparent progress: " << steps_without_progress;
}
else
{
std::ostringstream sout;
sout << "percent complete: " << std::fixed << std::setprecision(2) << 100.0*lr_schedule_pos/(double)lr_schedule.size() << "%";
std::cout << sout.str();
}
std::cout << std::endl;
}
std::vector<std::shared_ptr<device_data>> devices;
dlib::pipe<job_t> job_pipe;
job_t job;
@@ -857,10 +913,11 @@ namespace dlib
size_t epoch_pos;
std::chrono::time_point<std::chrono::system_clock> last_time;
unsigned long long train_one_step_calls;
matrix<double,0,1> lr_schedule;
long lr_schedule_pos;
unsigned long gradient_check_budget;
};
// ----------------------------------------------------------------------------------------
@@ -72,6 +72,7 @@ namespace dlib
- #get_min_learning_rate() == 1e-5
- #get_iterations_without_progress_threshold() == 2000
- #get_learning_rate_shrink() == 0.1
- #get_learning_rate_schedule().size() == 0
- if (cuda_extra_devices.size() > 0) then
- This object will use multiple graphics cards to run the learning
algorithms. In particular, it will always use whatever device is
@@ -152,6 +153,7 @@ namespace dlib
- lr > 0
ensures
- #get_learning_rate() == lr
- #get_learning_rate_schedule().size() == 0
- This function blocks until all threads inside the dnn_trainer have
stopped touching the net.
!*/
@@ -164,7 +166,9 @@ namespace dlib
of each layer in the network. It does this by outputting a step vector
that, when added to the parameters, will hopefully result in improved
network performance. The learning rate is one of the inputs to the
solver and influences the size of this step vector.
solver and influences the size of this step vector. This function
returns the current learning rate, that is, the learning rate that will
be used during the next training step.
!*/
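To make the learning rate's role concrete, here is a minimal, hypothetical
SGD-style update (not dlib's actual solver interface); the step applied to
the parameters scales directly with the learning rate:

#include <cstddef>
#include <vector>
// Illustration only: a plain gradient step whose size is proportional to
// the learning rate handed to the solver by the trainer.
void sgd_step(std::vector<double>& params, const std::vector<double>& grad, double learning_rate)
{
    for (std::size_t i = 0; i < params.size(); ++i)
        params[i] -= learning_rate * grad[i];  // larger learning rate => larger step
}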
void set_min_learning_rate (
@@ -175,6 +179,9 @@ namespace dlib
- lr > 0
ensures
- #get_min_learning_rate() == lr
- #get_learning_rate_schedule().size() == 0
- This function blocks until all threads inside the dnn_trainer have
stopped touching the net.
!*/
double get_min_learning_rate (
@@ -191,12 +198,49 @@ namespace dlib
learning rate will drop infinitely close to 0 if you run long enough.
!*/
template <typename EXP>
void set_learning_rate_schedule (
const matrix_exp<EXP>& schedule
);
/*!
requires
- schedule.size() > 0
- min(schedule) > 0
ensures
- #get_learning_rate_schedule() == reshape_to_column_vector(schedule)
- #get_learning_rate() == schedule(0,0)
- #get_min_learning_rate() == min(schedule)
- #get_learning_rate_shrink() == 1
!*/
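As a usage sketch (assuming a net_type network definition and a constructed
net object, neither of which appears in this diff), a linearly decaying
schedule covering 10000 mini-batches could be installed like this:

// Sketch: one learning rate per mini-batch, decaying linearly from 0.1 to 1e-4.
// linspace() produces 10000 evenly spaced values which the trainer consumes in order.
dnn_trainer<net_type> trainer(net);
trainer.set_learning_rate_schedule(linspace(0.1, 1e-4, 10000));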
const matrix<double,0,1>& get_learning_rate_schedule (
) const;
/*!
ensures
- if (this function returns a non-empty matrix) then
- This trainer will use an explicit learning rate schedule defined by
the learning rate values in get_learning_rate_schedule(). For
example, if get_learning_rate_schedule() returned {0.1, 0.09, 0.08,
0.07, 0.06} then the first training mini-batch would use a learning
rate of 0.1, the next mini-batch would use 0.09, then 0.08, and so
on until the end of the schedule is reached.
If you continue to run training after the end of the schedule has
been reached then the learning rate will be fixed to 0.99 times the
final value. So in our example, the learning rate would eventually
be fixed to 0.99*0.06. This lets you check whether the end of the
schedule has been reached by testing if get_learning_rate() < 0.06.
!*/
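Continuing the sketch above, this end-of-schedule behavior gives a simple
stopping test: once the schedule is exhausted the learning rate becomes 0.99
times its final value and therefore falls below min(schedule). Assuming
mini_batch_samples and mini_batch_labels are filled elsewhere:

// Sketch: train one mini-batch at a time until the schedule has been consumed.
while (trainer.get_learning_rate() >= 1e-4)
    trainer.train_one_step(mini_batch_samples, mini_batch_labels);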
void set_iterations_without_progress_threshold (
unsigned long thresh
);
/*!
ensures
- #get_iterations_without_progress_threshold() == thresh
- #get_learning_rate_schedule().size() == 0
- This function blocks until all threads inside the dnn_trainer have
stopped touching the net.
!*/
unsigned long get_iterations_without_progress_threshold (
@@ -225,6 +269,9 @@ namespace dlib
- 0 < shrink && shrink <= 1
ensures
- #get_learning_rate_shrink() == shrink
- #get_learning_rate_schedule().size() == 0
- This function blocks until all threads inside the dnn_trainer have
stopped touching the net.
!*/
double get_learning_rate_shrink (