Commit 362bec10 authored by Plumtus's avatar Plumtus Committed by Davis E. King

Reinitialize averagers when saved sync file was reloaded. (#629)

parent d2b80bfe
......@@ -713,7 +713,7 @@ namespace dlib
// We can't do this outside the loop because the tensors that get
// averaged need to be allocated to their devices before we call set()
// so that the averagers can determine how best to average them.
if (averagers.size() == 0)
if (averagers.size() == 0 || sync_file_reloaded)
{
averagers = std::vector<tt::multi_device_tensor_averager>(net_type::num_computational_layers);
// setup the averagers to point to the tensors in the networks.
......@@ -736,6 +736,8 @@ namespace dlib
if (temp[0]->size() != 0)
averagers[i].set(temp);
}
sync_file_reloaded = false;
}
......@@ -855,6 +857,7 @@ namespace dlib
prob_loss_increasing_thresh_max_value = 0.99999;
prob_loss_increasing_thresh = prob_loss_increasing_thresh_default_value;
updated_net_since_last_sync = false;
sync_file_reloaded = false;
start();
}
......@@ -979,6 +982,7 @@ namespace dlib
{
std::ifstream fin(sync_filename, std::ios::binary);
deserialize(*this, fin);
sync_file_reloaded = true;
if (verbose)
std::cout << "Loss has been increasing, reloading saved state from " << sync_filename << std::endl;
}
......@@ -1230,6 +1234,7 @@ namespace dlib
double prob_loss_increasing_thresh;
std::atomic<bool> updated_net_since_last_sync;
bool sync_file_reloaded;
};
// ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment