Commit fc8a335f authored by Davis King

Made the dnn_trainer check if the loss has been increasing before it saves the
state to disk. If it detects that the loss has been going up then, instead of
saving to disk, it recalls the previously good state.  This way, if we hit a
bad mini-batch during training which negatively affects the model in a
significant way, the dnn_trainer will automatically revert to an earlier
good state.
parent 6c66d22b
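For context, the revert-on-bad-loss behavior only comes into play when a synchronization file has been set. Below is a minimal sketch of a training loop that would exercise it, assuming the existing dnn_trainer API (set_synchronization_file, be_verbose, train); the network definition, data, and file name are hypothetical placeholders.

```cpp
#include <dlib/dnn.h>
#include <chrono>
#include <vector>

using namespace dlib;

// A tiny placeholder network; any loss/layer stack works the same way.
using net_type = loss_multiclass_log<fc<10, relu<fc<32, input<matrix<float>>>>>>;

int main()
{
    std::vector<matrix<float>> samples;  // assumed to be filled in elsewhere
    std::vector<unsigned long> labels;   // assumed to be filled in elsewhere

    net_type net;
    dnn_trainer<net_type> trainer(net);
    trainer.set_learning_rate(0.01);
    trainer.be_verbose();

    // With a synchronization file set, the trainer periodically saves its state.
    // After this commit, each sync first checks whether the loss has been rising
    // since the last save; if so, the trainer reloads the previously saved good
    // state instead of overwriting it with a worse one.
    trainer.set_synchronization_file("trainer_state_sync", std::chrono::minutes(5));

    trainer.train(samples, labels);
}
```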
@@ -585,10 +585,10 @@ namespace dlib
tp.push_back(std::make_shared<thread_pool>(1));
size_t iteration = 0;
main_iteration_counter = 0;
while(job_pipe.dequeue(next_job))
{
++iteration;
++main_iteration_counter;
// Call compute_parameter_gradients() and update_parameters() but pick the
// right version for unsupervised or supervised training based on the type
// of label_type.
@@ -656,7 +656,7 @@ namespace dlib
// the different networks may be initialized differently when tensor data
// is first passed through them. So this code block deals with these
// issues.
if (devices.size() > 1 && iteration%2000 == 1)
if (devices.size() > 1 && main_iteration_counter%2000 == 1)
{
for (size_t i = 1; i < devices.size(); ++i)
{
@@ -740,6 +740,9 @@ namespace dlib
train_one_step_calls = 0;
gradient_check_budget = 0;
lr_schedule_pos = 0;
main_iteration_counter = 0;
main_iteration_counter_at_last_disk_sync = 0;
start();
}
@@ -844,6 +847,22 @@ namespace dlib
// compact network before saving to disk.
this->net.clean();
// if the loss has actually been going up since the last time we saved our
// state to disk then something has probably gone wrong in the
// optimization. So in this case we do the opposite and recall the
// previously saved state in the hopes that the problem won't reoccur.
if (loss_increased_since_last_disk_sync())
{
// reload from the previous sync file. The file should exist since we
// checked that main_iteration_counter_at_last_disk_sync != 0.
std::ifstream fin(sync_filename, std::ios::binary);
deserialize(*this, fin);
if (verbose)
std::cout << "Loss has been increasing, reloading saved state from " << sync_filename << std::endl;
}
else
{
// save our state to a temp file
const std::string tempfile = sync_filename + ".tmp";
serialize(tempfile) << *this;
@@ -853,10 +872,35 @@ namespace dlib
std::remove(sync_filename.c_str());
std::rename(tempfile.c_str(), sync_filename.c_str());
last_sync_time = std::chrono::system_clock::now();
if (verbose)
std::cout << "Saved state to " << sync_filename << std::endl;
}
last_sync_time = std::chrono::system_clock::now();
main_iteration_counter_at_last_disk_sync = main_iteration_counter;
}
}
bool loss_increased_since_last_disk_sync() const
{
size_t gradient_updates_since_last_sync = main_iteration_counter - main_iteration_counter_at_last_disk_sync;
// if we haven't synced anything to disk yet then return false.
if (main_iteration_counter_at_last_disk_sync == 0)
return false;
// if we haven't seen much data yet then just say false.
if (gradient_updates_since_last_sync < 30 || previous_loss_values.size() < 2*gradient_updates_since_last_sync)
return false;
// Now look at the data since a little before the last disk sync. We will
// check if the loss is getting better or worse.
running_gradient g;
for (size_t i = previous_loss_values.size() - 2*gradient_updates_since_last_sync; i < previous_loss_values.size(); ++i)
g.add(previous_loss_values[i]);
// if the loss is very likely to be increasing then return true
return g.probability_gradient_greater_than(0) > 0.99;
}
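The check above leans on dlib's running_gradient object, which fits a line to the recent loss values and reports how likely the slope is to be positive. A standalone toy example of that statistic (made-up numbers, but the same add()/probability_gradient_greater_than() calls used in the diff):

```cpp
#include <dlib/statistics/running_gradient.h>
#include <iostream>

int main()
{
    dlib::running_gradient g;

    // A noisy but clearly increasing sequence, standing in for the recent
    // loss values examined by loss_increased_since_last_disk_sync().
    for (int i = 0; i < 60; ++i)
        g.add(1.0 + 0.05*i + ((i % 2 == 0) ? 0.1 : -0.1));

    std::cout << "estimated slope: " << g.gradient() << "\n";
    std::cout << "P(slope > 0):    " << g.probability_gradient_greater_than(0) << "\n";

    // The trainer only treats the loss as increasing when this probability
    // exceeds 0.99, so a few noisy upticks won't trigger a reload.
    if (g.probability_gradient_greater_than(0) > 0.99)
        std::cout << "loss would be considered to be increasing" << std::endl;
}
```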
@@ -991,6 +1035,10 @@ namespace dlib
std::rethrow_exception(eptr);
}
// These two variables are not serialized
size_t main_iteration_counter;
size_t main_iteration_counter_at_last_disk_sync;
};
// ----------------------------------------------------------------------------------------
......
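One detail carried over from the existing sync code above: the trainer serializes to a temporary file and only then swaps it into place, so an interrupted write can't clobber the previous good sync file. A minimal standalone sketch of that pattern (hypothetical file names, plain std::ofstream in place of dlib serialization):

```cpp
#include <cstdio>
#include <fstream>
#include <string>

// Write new_contents next to the target file, then swap it into place so the
// old file is only removed once the new one has been fully written.
void save_with_tempfile(const std::string& filename, const std::string& new_contents)
{
    const std::string tempfile = filename + ".tmp";

    {
        std::ofstream fout(tempfile, std::ios::binary);
        fout << new_contents;
    }  // close (and flush) the temp file before renaming it

    // Remove-then-rename mirrors the trainer code and also works on Windows,
    // where std::rename won't overwrite an existing file.
    std::remove(filename.c_str());
    std::rename(tempfile.c_str(), filename.c_str());
}
```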