Commit 52627635 authored by Davis King

Made the dnn_trainer's detection and backtracking from situations with
increasing loss more robust.  Now it will never get into a situation where it
backtracks over and over.  Instead, it will only backtrack a few times in a row
before just letting SGD run unimpeded.
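To see why the backtracking now self-limits: each backtrack moves the threshold a tenth of the remaining distance to 1 (0.99 → 0.999 → 0.9999 → ...), and a backtrack is only allowed while the threshold is at or below the 0.99999 cap, so starting from the 0.99 default the test can pass at most four times in a row. Below is a minimal standalone sketch of that arithmetic; the constants come from the patch below, but the loop itself is illustrative (it assumes the measured probability always exceeds the threshold) and is not dlib code.

#include <algorithm>
#include <cmath>
#include <iostream>

int main()
{
    const double default_thresh = 0.99;    // prob_loss_increasing_thresh_default_value
    const double max_thresh     = 0.99999; // prob_loss_increasing_thresh_max_value
    double thresh = default_thresh;

    // Consecutive "loss looks like it's increasing" detections.
    int backtracks = 0;
    while (thresh <= max_thresh)
    {
        ++backtracks;
        thresh = 0.1*thresh + 0.9;  // harden the test towards 1
    }
    std::cout << "consecutive backtracks before SGD runs unimpeded: " << backtracks << "\n";  // prints 4

    // Healthy iterations decay the threshold back down to its default.
    while (thresh > default_thresh)
        thresh = std::max(std::pow(thresh, 10.0), default_thresh);
    std::cout << "threshold after recovery: " << thresh << "\n";  // prints 0.99
}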
parent b19e139d
@@ -743,6 +743,9 @@ namespace dlib
             main_iteration_counter = 0;
             main_iteration_counter_at_last_disk_sync = 0;
+            prob_loss_increasing_thresh_default_value = 0.99;
+            prob_loss_increasing_thresh_max_value = 0.99999;
+            prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value;
             start();
         }
@@ -853,8 +856,6 @@ namespace dlib
                 // previously saved state in the hopes that the problem won't reoccur.
                 if (loss_increased_since_last_disk_sync())
                 {
-                    // reload from the previous sync file. The file should exist since we
-                    // checked that main_iteration_counter_at_last_disk_sync != 0.
                     std::ifstream fin(sync_filename, std::ios::binary);
                     deserialize(*this, fin);
                     if (verbose)
@@ -881,12 +882,12 @@ namespace dlib
             }
         }

-        bool loss_increased_since_last_disk_sync() const
+        bool loss_increased_since_last_disk_sync()
         {
             size_t gradient_updates_since_last_sync = main_iteration_counter - main_iteration_counter_at_last_disk_sync;

             // if we haven't synced anything to disk yet then return false.
-            if (main_iteration_counter_at_last_disk_sync == 0)
+            if (!std::ifstream(sync_filename, std::ios::binary))
                 return false;

             for (auto x : previous_loss_values)
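The new guard replaces the iteration-counter test with a direct existence check on the sync file: constructing a std::ifstream and testing it in boolean context is the standard way to ask whether a file can be opened. A minimal illustration of that idiom, with a hypothetical filename:

#include <fstream>
#include <iostream>

int main()
{
    // True only if the file exists and is readable, mirroring the new guard.
    const bool have_sync_file = static_cast<bool>(std::ifstream("my_trainer_state.dat", std::ios::binary));
    std::cout << (have_sync_file ? "can reload" : "nothing synced yet") << "\n";
}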
@@ -897,7 +898,8 @@ namespace dlib
                     return true;
             }

-            // if we haven't seen much data yet then just say false.
+            // if we haven't seen much data yet then just say false.  Or, alternatively, if
+            // it's been too long since the last sync then don't reload either.
             if (gradient_updates_since_last_sync < 30 || previous_loss_values.size() < 2*gradient_updates_since_last_sync)
                 return false;
@@ -908,7 +910,25 @@ namespace dlib
                 g.add(previous_loss_values[i]);

             // if the loss is very likely to be increasing then return true
-            return g.probability_gradient_greater_than(0) > 0.99;
+            const double prob = g.probability_gradient_greater_than(0);
+            if (prob > prob_loss_increasing_thresh && prob_loss_increasing_thresh <= prob_loss_increasing_thresh_max_value)
+            {
+                // Exponentially decay the threshold towards 1 so that if we keep finding
+                // the loss to be increasing over and over we will make the test
+                // progressively harder and harder until it fails, therefore ensuring we
+                // can't get stuck reloading from a previous state over and over.
+                prob_loss_increasing_thresh = 0.1*prob_loss_increasing_thresh + 0.9*1;
+                return true;
+            }
+            else
+            {
+                // decay back to the default threshold
+                prob_loss_increasing_thresh = std::pow(prob_loss_increasing_thresh, 10.0);
+                // but don't decay below the default value
+                prob_loss_increasing_thresh = std::max(prob_loss_increasing_thresh, prob_loss_increasing_thresh_default_value);
+                return false;
+            }
         }
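The probability itself comes from dlib's running_gradient statistic, which fits a line to recent loss values and reports how likely the slope is to be positive. A hedged usage sketch, separate from the trainer (the synthetic loss values are made up for illustration; in the real code the samples come from previous_loss_values):

#include <dlib/statistics/running_gradient.h>
#include <iostream>
#include <vector>

int main()
{
    // Made-up recent loss values with a mild upward trend.
    const std::vector<double> losses = {0.50, 0.52, 0.51, 0.55, 0.56, 0.58, 0.60, 0.61};

    dlib::running_gradient g;
    for (auto x : losses)
        g.add(x);

    // Probability that the fitted loss trend has a positive slope, the same
    // quantity compared against prob_loss_increasing_thresh above.
    std::cout << "P(loss increasing) = " << g.probability_gradient_greater_than(0) << "\n";
}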
@@ -1043,9 +1063,12 @@ namespace dlib
                 std::rethrow_exception(eptr);
         }

-        // These two variables are not serialized
+        // These 5 variables are not serialized
         size_t main_iteration_counter;
         size_t main_iteration_counter_at_last_disk_sync;
+        double prob_loss_increasing_thresh_default_value;
+        double prob_loss_increasing_thresh_max_value;
+        double prob_loss_increasing_thresh;
     };