Commit dd62b0e2 authored by Davis King

Made the dnn_trainer not forget all the previous loss values it knows about

when it determines that there have been a lot of steps without progress and
shrinks the learning rate.  Instead, it now removes only the oldest 100.  The
problem with the old behavior of clearing the entire loss history was that if
you set the steps-without-progress threshold to a very high number you would
often see that the last few learning rates were obviously not making progress.
However, since all the previous loss values had been forgotten, the trainer had
to fully repopulate its loss history from scratch before it could figure this
out.  The new behavior keeps the trainer from wasting time on this excessive
optimization of obviously useless mini-batches.
parent 618f1084
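For illustration, here is a minimal standalone sketch of the new trimming behavior.  The free function and the std::deque parameter are assumptions made for the sketch; in the trainer the equivalent loop runs inline on its loss history members, as shown in the diff below.

#include <deque>

// Sketch: instead of discarding the whole loss history when the learning rate
// shrinks, drop only the oldest dump_amount entries (100 by default).  The
// surviving values still describe the recent mini-batches, so the next
// learning rate drop can be detected without rebuilding the history from
// scratch.
void trim_loss_history(std::deque<double>& previous_loss_values, int dump_amount = 100)
{
    for (int cnt = 0; cnt < dump_amount && previous_loss_values.size() > 0; ++cnt)
        previous_loss_values.pop_front();
}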
@@ -699,7 +699,10 @@ namespace dlib
                         // optimization has flattened out, so drop the learning rate.
                         learning_rate = learning_rate_shrink*learning_rate;
                         test_steps_without_progress = 0;
-                        test_previous_loss_values.clear();
+                        // Empty out some of the previous loss values so that test_steps_without_progress
+                        // will decrease below test_iter_without_progress_thresh.
+                        for (int cnt = 0; cnt < previous_loss_values_dump_amount && test_previous_loss_values.size() > 0; ++cnt)
+                            test_previous_loss_values.pop_front();
                     }
                 }
             }
@@ -804,7 +807,7 @@ namespace dlib
                     // this because sometimes a mini-batch might be bad and cause the
                     // loss to suddenly jump up, making count_steps_without_decrease()
                     // return a large number.  But if we discard the top 10% of the
-                    // values in previous_loss_values when we are robust to that kind
+                    // values in previous_loss_values then we are robust to that kind
                     // of noise.  Another way of looking at it, if the reason
                     // count_steps_without_decrease() returns a large value is only
                     // because the most recent loss values have suddenly been large,
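The comment above describes why the progress check needs to be robust to occasional bad mini-batches.  As an illustrative simplification only, not dlib's actual count_steps_without_decrease_robust() implementation, the idea of discarding the top 10% of losses before counting can be sketched like this:

#include <algorithm>
#include <cstddef>
#include <deque>
#include <vector>

// Count how many steps it has been since the loss last improved, while
// ignoring losses above the 90th percentile so a single unlucky mini-batch
// cannot inflate the count and trigger a spurious learning rate drop.
unsigned long robust_steps_without_decrease(const std::deque<double>& losses)
{
    if (losses.size() < 2)
        return 0;

    // Treat the largest 10% of recorded losses as outliers.
    std::vector<double> sorted(losses.begin(), losses.end());
    std::sort(sorted.begin(), sorted.end());
    const double cutoff = sorted[static_cast<std::size_t>(0.9*(sorted.size()-1))];

    unsigned long steps = 0;
    double best_so_far = losses.front();
    for (std::size_t i = 1; i < losses.size(); ++i)
    {
        if (losses[i] > cutoff)
            continue;                 // outlier: don't let it inflate the count
        if (losses[i] < best_so_far)
        {
            best_so_far = losses[i];  // progress was made, reset the counter
            steps = 0;
        }
        else
        {
            ++steps;
        }
    }
    return steps;
}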
@@ -816,7 +819,10 @@ namespace dlib
                         // optimization has flattened out, so drop the learning rate.
                         learning_rate = learning_rate_shrink*learning_rate;
                         steps_without_progress = 0;
-                        previous_loss_values.clear();
+                        // Empty out some of the previous loss values so that steps_without_progress
+                        // will decrease below iter_without_progress_thresh.
+                        for (int cnt = 0; cnt < previous_loss_values_dump_amount && previous_loss_values.size() > 0; ++cnt)
+                            previous_loss_values.pop_front();
                     }
                 }
             }
@@ -873,6 +879,7 @@ namespace dlib
             prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value;
             updated_net_since_last_sync = false;
             sync_file_reloaded = false;
+            previous_loss_values_dump_amount = 100;
             start();
         }
@@ -883,7 +890,7 @@ namespace dlib
         friend void serialize(const dnn_trainer& item, std::ostream& out)
         {
             item.wait_for_thread_to_pause();
-            int version = 9;
+            int version = 10;
             serialize(version, out);
             size_t nl = dnn_trainer::num_layers;
@@ -909,6 +916,7 @@ namespace dlib
             serialize(item.test_iter_without_progress_thresh.load(), out);
             serialize(item.test_steps_without_progress.load(), out);
             serialize(item.test_previous_loss_values, out);
+            serialize(item.previous_loss_values_dump_amount, out);
         }

         friend void deserialize(dnn_trainer& item, std::istream& in)
@@ -916,7 +924,7 @@ namespace dlib
             item.wait_for_thread_to_pause();
             int version = 0;
             deserialize(version, in);
-            if (version != 9)
+            if (version != 10)
                 throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
             size_t num_layers = 0;
@@ -952,6 +960,7 @@ namespace dlib
             deserialize(ltemp, in); item.test_iter_without_progress_thresh = ltemp;
             deserialize(ltemp, in); item.test_steps_without_progress = ltemp;
             deserialize(item.test_previous_loss_values, in);
+            deserialize(item.previous_loss_values_dump_amount, in);
             if (item.devices.size() > 1)
             {
@@ -1259,6 +1268,7 @@ namespace dlib
         std::atomic<bool> updated_net_since_last_sync;
         bool sync_file_reloaded;
+        int previous_loss_values_dump_amount;
     };

 // ----------------------------------------------------------------------------------------
...
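The serialization changes follow the usual pattern for evolving a stream format: the new member is written at the end of the stream and the version number is bumped so archives produced by older code are rejected rather than misread.  A standalone sketch of that pattern, using a hypothetical trainer_state struct rather than the real dnn_trainer:

#include <iostream>
#include <vector>
#include <dlib/serialize.h>

// Hypothetical struct standing in for the trainer's serialized state.
struct trainer_state
{
    std::vector<double> previous_loss_values;
    int previous_loss_values_dump_amount = 100;   // new field added in version 10
};

void serialize(const trainer_state& item, std::ostream& out)
{
    int version = 10;                              // bumped from 9 when the field was added
    dlib::serialize(version, out);
    dlib::serialize(item.previous_loss_values, out);
    dlib::serialize(item.previous_loss_values_dump_amount, out);  // appended at the end
}

void deserialize(trainer_state& item, std::istream& in)
{
    int version = 0;
    dlib::deserialize(version, in);
    if (version != 10)
        throw dlib::serialization_error("Unexpected version found while deserializing trainer_state.");
    dlib::deserialize(item.previous_loss_values, in);
    dlib::deserialize(item.previous_loss_values_dump_amount, in);
}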