Commit 37e422dc authored by Davis King

Finished the cuDNN convolution binding.

parent 88f5d9a3
...@@ -166,6 +166,12 @@ namespace dlib ...@@ -166,6 +166,12 @@ namespace dlib
const tensor& src const tensor& src
) )
{ {
DLIB_CASSERT(
(dest.num_samples()==src.num_samples() || src.num_samples()==1) &&
(dest.nr()==src.nr() || src.nr()==1) &&
(dest.nc()==src.nc() || src.nc()==1) &&
(dest.k()==src.k() || src.k()==1), "");
check(cudnnAddTensor_v3(context(), check(cudnnAddTensor_v3(context(),
&alpha, &alpha,
descriptor(src), descriptor(src),
...@@ -215,7 +221,13 @@ namespace dlib ...@@ -215,7 +221,13 @@ namespace dlib
out_nc(0), out_nc(0),
forward_algo(0), forward_algo(0),
forward_workspace_size_in_bytes(0), forward_workspace_size_in_bytes(0),
forward_workspace(nullptr) forward_workspace(nullptr),
backward_data_algo(0),
backward_data_workspace_size_in_bytes(0),
backward_data_workspace(nullptr),
backward_filters_algo(0),
backward_filters_workspace_size_in_bytes(0),
backward_filters_workspace(nullptr)
{ {
} }
...@@ -237,9 +249,20 @@ namespace dlib ...@@ -237,9 +249,20 @@ namespace dlib
if (forward_workspace) if (forward_workspace)
cudaFree(forward_workspace); cudaFree(forward_workspace);
forward_workspace = nullptr; forward_workspace = nullptr;
forward_algo = 0; forward_algo = 0;
forward_workspace_size_in_bytes = 0; forward_workspace_size_in_bytes = 0;
if (backward_data_workspace)
cudaFree(backward_data_workspace);
backward_data_workspace = nullptr;
backward_data_algo = 0;
backward_data_workspace_size_in_bytes = 0;
if (backward_filters_workspace)
cudaFree(backward_filters_workspace);
backward_filters_workspace = nullptr;
backward_filters_algo = 0;
backward_filters_workspace_size_in_bytes = 0;
} }
void conv:: void conv::
...@@ -286,6 +309,8 @@ namespace dlib ...@@ -286,6 +309,8 @@ namespace dlib
tensor_descriptor dest_desc; tensor_descriptor dest_desc;
dest_desc.set_size(out_num_samples,out_k,out_nr,out_nc); dest_desc.set_size(out_num_samples,out_k,out_nr,out_nc);
// Pick which forward algorithm we will use and allocate the necessary
// workspace buffer.
cudnnConvolutionFwdAlgo_t forward_best_algo; cudnnConvolutionFwdAlgo_t forward_best_algo;
check(cudnnGetConvolutionForwardAlgorithm( check(cudnnGetConvolutionForwardAlgorithm(
context(), context(),
...@@ -297,8 +322,6 @@ namespace dlib ...@@ -297,8 +322,6 @@ namespace dlib
std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max(),
&forward_best_algo)); &forward_best_algo));
forward_algo = forward_best_algo; forward_algo = forward_best_algo;
check(cudnnGetConvolutionForwardWorkspaceSize( check(cudnnGetConvolutionForwardWorkspaceSize(
context(), context(),
descriptor(data), descriptor(data),
...@@ -307,8 +330,55 @@ namespace dlib ...@@ -307,8 +330,55 @@ namespace dlib
descriptor(dest_desc), descriptor(dest_desc),
forward_best_algo, forward_best_algo,
&forward_workspace_size_in_bytes)); &forward_workspace_size_in_bytes));
CHECK_CUDA(cudaMalloc(&forward_workspace, forward_workspace_size_in_bytes)); CHECK_CUDA(cudaMalloc(&forward_workspace, forward_workspace_size_in_bytes));
// Pick which backward data algorithm we will use and allocate the
// necessary workspace buffer.
cudnnConvolutionBwdDataAlgo_t backward_data_best_algo;
check(cudnnGetConvolutionBackwardDataAlgorithm(
context(),
(const cudnnFilterDescriptor_t)filter_handle,
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(data),
CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
std::numeric_limits<size_t>::max(),
&backward_data_best_algo));
backward_data_algo = backward_data_best_algo;
check(cudnnGetConvolutionBackwardDataWorkspaceSize(
context(),
(const cudnnFilterDescriptor_t)filter_handle,
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
descriptor(data),
backward_data_best_algo,
&backward_data_workspace_size_in_bytes));
CHECK_CUDA(cudaMalloc(&backward_data_workspace, backward_data_workspace_size_in_bytes));
// Pick which backward filters algorithm we will use and allocate the
// necessary workspace buffer.
cudnnConvolutionBwdFilterAlgo_t backward_filters_best_algo;
check(cudnnGetConvolutionBackwardFilterAlgorithm(
context(),
descriptor(data),
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnFilterDescriptor_t)filter_handle,
CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
std::numeric_limits<size_t>::max(),
&backward_filters_best_algo));
backward_filters_algo = backward_filters_best_algo;
check(cudnnGetConvolutionBackwardFilterWorkspaceSize(
context(),
descriptor(data),
descriptor(dest_desc),
(const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnFilterDescriptor_t)filter_handle,
backward_filters_best_algo,
&backward_filters_workspace_size_in_bytes));
CHECK_CUDA(cudaMalloc(&backward_filters_workspace, backward_filters_workspace_size_in_bytes));
} }
catch(...) catch(...)
{ {
...@@ -362,6 +432,23 @@ namespace dlib ...@@ -362,6 +432,23 @@ namespace dlib
tensor& data_gradient tensor& data_gradient
) )
{ {
const float alpha = 1;
const float beta = 1;
check(cudnnConvolutionBackwardData_v3(context(),
&alpha,
(const cudnnFilterDescriptor_t)filter_handle,
filters.device(),
descriptor(gradient_input),
gradient_input.device(),
(const cudnnConvolutionDescriptor_t)conv_handle,
(cudnnConvolutionBwdDataAlgo_t)backward_data_algo,
backward_data_workspace,
backward_data_workspace_size_in_bytes,
&beta,
descriptor(data_gradient),
data_gradient.device()));
} }
void conv:: void conv::
...@@ -371,6 +458,21 @@ namespace dlib ...@@ -371,6 +458,21 @@ namespace dlib
tensor& filters_gradient tensor& filters_gradient
) )
{ {
const float alpha = 1;
const float beta = 1;
check(cudnnConvolutionBackwardFilter_v3(context(),
&alpha,
descriptor(data),
data.device(),
descriptor(gradient_input),
gradient_input.device(),
(const cudnnConvolutionDescriptor_t)conv_handle,
(cudnnConvolutionBwdFilterAlgo_t)backward_filters_algo,
backward_filters_workspace,
backward_filters_workspace_size_in_bytes,
&beta,
(const cudnnFilterDescriptor_t)filter_handle,
filters_gradient.device()));
} }
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
......
...@@ -63,10 +63,6 @@ namespace dlib ...@@ -63,10 +63,6 @@ namespace dlib
void* handle; void* handle;
}; };
// ------------------------------------------------------------------------------------
// add a call that maps to cudnnConvolutionBackwardBias()
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
void add( void add(
...@@ -111,6 +107,11 @@ namespace dlib ...@@ -111,6 +107,11 @@ namespace dlib
- E = E*value - E = E*value
!*/ !*/
// ------------------------------------------------------------------------------------
// TODO
// add a call that maps to cudnnConvolutionBackwardBias()
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
class conv class conv
...@@ -145,8 +146,8 @@ namespace dlib ...@@ -145,8 +146,8 @@ namespace dlib
); );
/*! /*!
requires requires
- the dimensions of data and filters are the same as the ones given - The dimensions of data and filters are the same as the ones given
to the constructor. to the last call to setup().
ensures ensures
- convolves filters over data. - convolves filters over data.
- filters contains filters.num_samples() filters. - filters contains filters.num_samples() filters.
...@@ -156,7 +157,6 @@ namespace dlib ...@@ -156,7 +157,6 @@ namespace dlib
- #output.nc() == 1+(data.nc()-1)/stride_x - #output.nc() == 1+(data.nc()-1)/stride_x
!*/ !*/
// get gradient of data: 4.49. cudnnConvolutionBackwardData_v3
void get_gradient_for_data ( void get_gradient_for_data (
const tensor& gradient_input, const tensor& gradient_input,
const tensor& filters, const tensor& filters,
...@@ -165,9 +165,9 @@ namespace dlib ...@@ -165,9 +165,9 @@ namespace dlib
/*! /*!
requires requires
- filters has the same dimensions as the filters object given to the - filters has the same dimensions as the filters object given to the
constructor. last call to setup().
- data_gradient has the same dimensions as the data object given to the - data_gradient has the same dimensions as the data object given to the
constructor. last call to setup().
- gradient_input has the same dimensions as the output of operator(). - gradient_input has the same dimensions as the output of operator().
ensures ensures
- let OUT be the output of (*this)(OUT,data,filters). - let OUT be the output of (*this)(OUT,data,filters).
...@@ -176,7 +176,6 @@ namespace dlib ...@@ -176,7 +176,6 @@ namespace dlib
and adds this gradient to data_gradient. and adds this gradient to data_gradient.
!*/ !*/
// get gradient of filters: 4.44. cudnnConvolutionBackwardFilter_v3
void get_gradient_for_filters ( void get_gradient_for_filters (
const tensor& gradient_input, const tensor& gradient_input,
const tensor& data, const tensor& data,
...@@ -185,8 +184,9 @@ namespace dlib ...@@ -185,8 +184,9 @@ namespace dlib
/*! /*!
requires requires
- filters_gradient has the same dimensions as the filters object given - filters_gradient has the same dimensions as the filters object given
to the constructor. to the last call to setup().
- data has the same dimensions as the data object given to the constructor. - data has the same dimensions as the data object given to the last call
to setup().
- gradient_input has the same dimensions as the output of operator(). - gradient_input has the same dimensions as the output of operator().
ensures ensures
- let OUT be the output of (*this)(OUT,data,filters). - let OUT be the output of (*this)(OUT,data,filters).
...@@ -210,6 +210,14 @@ namespace dlib ...@@ -210,6 +210,14 @@ namespace dlib
int forward_algo; int forward_algo;
size_t forward_workspace_size_in_bytes; size_t forward_workspace_size_in_bytes;
void* forward_workspace; void* forward_workspace;
int backward_data_algo;
size_t backward_data_workspace_size_in_bytes;
void* backward_data_workspace;
int backward_filters_algo;
size_t backward_filters_workspace_size_in_bytes;
void* backward_filters_workspace;
}; };
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment