Commit 1f5335c1 authored by Davis King

Made tensor_conv hold references to the cuda_data_void_ptr work buffers in the
member area of the class.  This way, we avoid a potential error where the
buffers are reallocated while cuDNN is still using them in the background.
parent 863702f0
......@@ -795,6 +795,9 @@ namespace dlib
backward_data_workspace_size_in_bytes = 0;
backward_filters_workspace_size_in_bytes = 0;
forward_workspace.reset();
backward_data_workspace.reset();
backward_filters_workspace.reset();
workspace.reset();
}
......@@ -1030,6 +1033,13 @@ namespace dlib
const float alpha = 1;
const float beta = add_to_output ? 1 : 0;
// Since cudnnConvolutionForward() is an asynchronous call, we need to hold a
// reference to the workspace buffer so we can be sure it isn't reallocated
// while the function is still executing on the device. But each time we come
// here, we make sure to grab the latest workspace buffer so that, globally, we
// minimize the number of such buffers.
forward_workspace = workspace->get(forward_workspace_size_in_bytes);
CHECK_CUDNN(cudnnConvolutionForward(
context(),
&alpha,
......@@ -1039,7 +1049,7 @@ namespace dlib
filters.device(),
(const cudnnConvolutionDescriptor_t)conv_handle,
(cudnnConvolutionFwdAlgo_t)forward_algo,
workspace->get(forward_workspace_size_in_bytes),
forward_workspace,
forward_workspace_size_in_bytes,
&beta,
descriptor(output),
......@@ -1056,6 +1066,13 @@ namespace dlib
const float alpha = 1;
const float beta = add_to_output ? 1 : 0;
// Since cudnnConvolutionBackwardData() is an asynchronous call, we need to hold a
// reference to the workspace buffer so we can be sure it isn't reallocated
// while the function is still executing on the device. But each time we come
// here, we make sure to grab the latest workspace buffer so that, globally, we
// minimize the number of such buffers.
backward_data_workspace = workspace->get(backward_data_workspace_size_in_bytes);
CHECK_CUDNN(cudnnConvolutionBackwardData(context(),
&alpha,
......@@ -1065,7 +1082,7 @@ namespace dlib
gradient_input.device(),
(const cudnnConvolutionDescriptor_t)conv_handle,
(cudnnConvolutionBwdDataAlgo_t)backward_data_algo,
workspace->get(backward_data_workspace_size_in_bytes),
backward_data_workspace,
backward_data_workspace_size_in_bytes,
&beta,
descriptor(data_gradient),
......@@ -1082,6 +1099,14 @@ namespace dlib
{
const float alpha = 1;
const float beta = add_to_output ? 1 : 0;
// Since cudnnConvolutionBackwardFilter() is an asynchronous call, we need to hold a
// reference to the workspace buffer so we can be sure it isn't reallocated
// while the function is still executing on the device. But each time we come
// here, we make sure to grab the latest workspace buffer so that, globally, we
// minimize the number of such buffers.
backward_filters_workspace = workspace->get(backward_filters_workspace_size_in_bytes);
CHECK_CUDNN(cudnnConvolutionBackwardFilter(context(),
&alpha,
descriptor(data),
......@@ -1090,7 +1115,7 @@ namespace dlib
gradient_input.device(),
(const cudnnConvolutionDescriptor_t)conv_handle,
(cudnnConvolutionBwdFilterAlgo_t)backward_filters_algo,
workspace->get(backward_filters_workspace_size_in_bytes),
backward_filters_workspace,
backward_filters_workspace_size_in_bytes,
&beta,
(const cudnnFilterDescriptor_t)filter_handle,
......
......@@ -269,6 +269,9 @@ namespace dlib
size_t backward_data_workspace_size_in_bytes;
size_t backward_filters_workspace_size_in_bytes;
std::shared_ptr<resizable_cuda_buffer> workspace;
cuda_data_void_ptr forward_workspace;
cuda_data_void_ptr backward_data_workspace;
cuda_data_void_ptr backward_filters_workspace;
};
// ------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment