Commit 362bec10 authored by Plumtus's avatar Plumtus Committed by Davis E. King

Reinitialize averagers when saved sync file was reloaded. (#629)

parent d2b80bfe
...@@ -713,7 +713,7 @@ namespace dlib ...@@ -713,7 +713,7 @@ namespace dlib
// We can't do this outside the loop because the tensors that get // We can't do this outside the loop because the tensors that get
// averaged need to be allocated to their devices before we call set() // averaged need to be allocated to their devices before we call set()
// so that the averagers can determine how best to average them. // so that the averagers can determine how best to average them.
if (averagers.size() == 0) if (averagers.size() == 0 || sync_file_reloaded)
{ {
averagers = std::vector<tt::multi_device_tensor_averager>(net_type::num_computational_layers); averagers = std::vector<tt::multi_device_tensor_averager>(net_type::num_computational_layers);
// setup the averagers to point to the tensors in the networks. // setup the averagers to point to the tensors in the networks.
...@@ -736,6 +736,8 @@ namespace dlib ...@@ -736,6 +736,8 @@ namespace dlib
if (temp[0]->size() != 0) if (temp[0]->size() != 0)
averagers[i].set(temp); averagers[i].set(temp);
} }
sync_file_reloaded = false;
} }
...@@ -855,6 +857,7 @@ namespace dlib ...@@ -855,6 +857,7 @@ namespace dlib
prob_loss_increasing_thresh_max_value = 0.99999; prob_loss_increasing_thresh_max_value = 0.99999;
prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value; prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value;
updated_net_since_last_sync = false; updated_net_since_last_sync = false;
sync_file_reloaded = false;
start(); start();
} }
...@@ -979,6 +982,7 @@ namespace dlib ...@@ -979,6 +982,7 @@ namespace dlib
{ {
std::ifstream fin(sync_filename, std::ios::binary); std::ifstream fin(sync_filename, std::ios::binary);
deserialize(*this, fin); deserialize(*this, fin);
sync_file_reloaded = true;
if (verbose) if (verbose)
std::cout << "Loss has been increasing, reloading saved state from " << sync_filename << std::endl; std::cout << "Loss has been increasing, reloading saved state from " << sync_filename << std::endl;
} }
...@@ -1230,6 +1234,7 @@ namespace dlib ...@@ -1230,6 +1234,7 @@ namespace dlib
double prob_loss_increasing_thresh; double prob_loss_increasing_thresh;
std::atomic<bool> updated_net_since_last_sync; std::atomic<bool> updated_net_since_last_sync;
bool sync_file_reloaded;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment