Commit ccd8b64f, authored by Juha Reunanen and committed by Davis E. King

Semantic-segmentation loss calculation: fix buffer usage on multi-GPU training (#1717)

* Semantic-segmentation loss calculation: fix buffer usage on multi-GPU training

* Review fix: make the work buffer live longer
parent 9433bfd6
...@@ -423,7 +423,6 @@ namespace dlib ...@@ -423,7 +423,6 @@ namespace dlib
compute_loss_multiclass_log_per_pixel( compute_loss_multiclass_log_per_pixel(
) )
{ {
work = device_global_buffer();
} }
template < template <
...@@ -439,6 +438,10 @@ namespace dlib ...@@ -439,6 +438,10 @@ namespace dlib
const size_t bytes_per_plane = subnetwork_output.nr()*subnetwork_output.nc()*sizeof(uint16_t); const size_t bytes_per_plane = subnetwork_output.nr()*subnetwork_output.nc()*sizeof(uint16_t);
// Allocate a cuda buffer to store all the truth images and also one float // Allocate a cuda buffer to store all the truth images and also one float
// for the scalar loss output. // for the scalar loss output.
if (!work)
{
work = device_global_buffer();
}
cuda_data_void_ptr buf = work->get(subnetwork_output.num_samples()*bytes_per_plane + sizeof(float)); cuda_data_void_ptr buf = work->get(subnetwork_output.num_samples()*bytes_per_plane + sizeof(float));
cuda_data_void_ptr loss_buf = buf; cuda_data_void_ptr loss_buf = buf;
...@@ -467,7 +470,7 @@ namespace dlib ...@@ -467,7 +470,7 @@ namespace dlib
double& loss double& loss
); );
std::shared_ptr<resizable_cuda_buffer> work; mutable std::shared_ptr<resizable_cuda_buffer> work;
}; };
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment