Commit 0d6e3f12 authored by Davis King's avatar Davis King

Made the multi_device_tensor_averager not assume the size of the tensors is

known at set() time.
parent b85688ac
......@@ -1067,14 +1067,6 @@ namespace dlib { namespace tt
epa.emplace_back(new enable_peer_access(*g[0], *g[i]));
}
}
// If there are multiple groups then we need to use the accum_buffer space
// when talking across groups. So allocate that buffer now.
if (accessible_groups.size() > 1)
{
raii_set_device set_dev(*accessible_groups[0][0]);
accum_buffer.copy_size(*accessible_groups[0][0]);
}
}
void average()
......@@ -1108,6 +1100,7 @@ namespace dlib { namespace tt
{
tensor& total_avg = *accessible_groups[0][0];
raii_set_device set_dev(total_avg);
accum_buffer.copy_size(total_avg);
// now we need to average things across groups
for (size_t i = 1; i < accessible_groups.size(); ++i)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment