Commit e6437d7d authored by Davis King's avatar Davis King

Made the dnn_trainer check for convergence every iteration rather than only

once every few thousand iterations.
parent 0ecff0e6
......@@ -396,7 +396,9 @@ namespace dlib
{
double loss = net.update(next_job.t, next_job.labels.begin(), make_sstack(solvers),step_size);
rs.add(loss);
rg.add(loss);
previous_loss_values.push_back(loss);
if (previous_loss_values.size() > iter_between_step_size_adjust)
previous_loss_values.pop_front();
}
void run_update(job_t& next_job, const no_label_type&)
......@@ -404,7 +406,9 @@ namespace dlib
no_label_type pick_wich_run_update;
double loss = net.update(next_job.t, make_sstack(solvers), step_size);
rs.add(loss);
rg.add(loss);
previous_loss_values.push_back(loss);
if (previous_loss_values.size() > iter_between_step_size_adjust)
previous_loss_values.pop_front();
}
void thread() try
......@@ -422,13 +426,13 @@ namespace dlib
// If we have been running for a while then check if the loss is still
// dropping. If it isn't then we will reduce the step size.
if (rg.current_n() > iter_between_step_size_adjust)
if (previous_loss_values.size() >= iter_between_step_size_adjust)
{
if (rg.probability_gradient_greater_than(0) > 0.45)
if (probability_gradient_greater_than(previous_loss_values, 0) > 0.49)
{
step_size = step_size_shrink*step_size;
previous_loss_values.clear();
}
rg.clear();
}
}
}
......@@ -470,13 +474,13 @@ namespace dlib
friend void serialize(const dnn_trainer& item, std::ostream& out)
{
item.wait_for_thread_to_pause();
int version = 3;
int version = 4;
serialize(version, out);
size_t nl = dnn_trainer::num_layers;
serialize(nl, out);
serialize(item.rs, out);
serialize(item.rg, out);
serialize(item.previous_loss_values, out);
serialize(item.max_num_epochs, out);
serialize(item.mini_batch_size, out);
serialize(item.verbose, out);
......@@ -495,7 +499,7 @@ namespace dlib
item.wait_for_thread_to_pause();
int version = 0;
deserialize(version, in);
if (version != 3)
if (version != 4)
throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
size_t num_layers = 0;
......@@ -511,7 +515,7 @@ namespace dlib
double dtemp; long ltemp;
deserialize(item.rs, in);
deserialize(item.rg, in);
deserialize(item.previous_loss_values, in);
deserialize(item.max_num_epochs, in);
deserialize(item.mini_batch_size, in);
deserialize(item.verbose, in);
......@@ -562,7 +566,7 @@ namespace dlib
dlib::pipe<job_t> job_pipe;
running_stats<double> rs;
running_gradient rg;
std::deque<double> previous_loss_values;
unsigned long max_num_epochs;
size_t mini_batch_size;
bool verbose;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment