Commit 1db46949 authored by Davis King's avatar Davis King

Made the test loss in the verbose output messages from the dnn_trainer not jump

in variance when the learning rate resets.
parent d009916e
...@@ -433,10 +433,7 @@ namespace dlib ...@@ -433,10 +433,7 @@ namespace dlib
) const ) const
{ {
wait_for_thread_to_pause(); wait_for_thread_to_pause();
running_stats<double> tmp; return rs_test.mean();
for (auto& x : test_previous_loss_values)
tmp.add(x);
return tmp.mean();
} }
void clear_average_loss ( void clear_average_loss (
...@@ -582,6 +579,7 @@ namespace dlib ...@@ -582,6 +579,7 @@ namespace dlib
void record_test_loss(double loss) void record_test_loss(double loss)
{ {
test_previous_loss_values.push_back(loss); test_previous_loss_values.push_back(loss);
rs_test.add(loss);
// discard really old loss values. // discard really old loss values.
while (test_previous_loss_values.size() > test_iter_without_progress_thresh) while (test_previous_loss_values.size() > test_iter_without_progress_thresh)
test_previous_loss_values.pop_front(); test_previous_loss_values.pop_front();
...@@ -881,6 +879,9 @@ namespace dlib ...@@ -881,6 +879,9 @@ namespace dlib
sync_file_reloaded = false; sync_file_reloaded = false;
previous_loss_values_dump_amount = 400; previous_loss_values_dump_amount = 400;
test_previous_loss_values_dump_amount = 100; test_previous_loss_values_dump_amount = 100;
rs_test = running_stats_decayed<double>(200);
start(); start();
} }
...@@ -891,12 +892,13 @@ namespace dlib ...@@ -891,12 +892,13 @@ namespace dlib
friend void serialize(const dnn_trainer& item, std::ostream& out) friend void serialize(const dnn_trainer& item, std::ostream& out)
{ {
item.wait_for_thread_to_pause(); item.wait_for_thread_to_pause();
int version = 11; int version = 12;
serialize(version, out); serialize(version, out);
size_t nl = dnn_trainer::num_layers; size_t nl = dnn_trainer::num_layers;
serialize(nl, out); serialize(nl, out);
serialize(item.rs, out); serialize(item.rs, out);
serialize(item.rs_test, out);
serialize(item.previous_loss_values, out); serialize(item.previous_loss_values, out);
serialize(item.max_num_epochs, out); serialize(item.max_num_epochs, out);
serialize(item.mini_batch_size, out); serialize(item.mini_batch_size, out);
...@@ -926,7 +928,7 @@ namespace dlib ...@@ -926,7 +928,7 @@ namespace dlib
item.wait_for_thread_to_pause(); item.wait_for_thread_to_pause();
int version = 0; int version = 0;
deserialize(version, in); deserialize(version, in);
if (version != 11) if (version != 12)
throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer."); throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
size_t num_layers = 0; size_t num_layers = 0;
...@@ -942,6 +944,7 @@ namespace dlib ...@@ -942,6 +944,7 @@ namespace dlib
double dtemp; long ltemp; double dtemp; long ltemp;
deserialize(item.rs, in); deserialize(item.rs, in);
deserialize(item.rs_test, in);
deserialize(item.previous_loss_values, in); deserialize(item.previous_loss_values, in);
deserialize(item.max_num_epochs, in); deserialize(item.max_num_epochs, in);
deserialize(item.mini_batch_size, in); deserialize(item.mini_batch_size, in);
...@@ -1226,6 +1229,7 @@ namespace dlib ...@@ -1226,6 +1229,7 @@ namespace dlib
running_stats<double> rs; running_stats<double> rs;
running_stats_decayed<double> rs_test;
std::deque<double> previous_loss_values; std::deque<double> previous_loss_values;
unsigned long max_num_epochs; unsigned long max_num_epochs;
size_t mini_batch_size; size_t mini_batch_size;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment