Commit 6a1e50a4 authored by Davis King's avatar Davis King

setup cuDNN forward convolution

parent 10480cb9
...@@ -182,7 +182,10 @@ namespace dlib ...@@ -182,7 +182,10 @@ namespace dlib
out_num_samples(0), out_num_samples(0),
out_k(0), out_k(0),
out_nr(0), out_nr(0),
out_nc(0) out_nc(0),
forward_algo(0),
forward_workspace_size_in_bytes(0),
forward_workspace(nullptr)
{ {
} }
...@@ -200,19 +203,30 @@ namespace dlib ...@@ -200,19 +203,30 @@ namespace dlib
out_k = 0; out_k = 0;
out_nr = 0; out_nr = 0;
out_nc = 0; out_nc = 0;
if (forward_workspace)
cudaFree(forward_workspace);
forward_workspace = nullptr;
forward_algo = 0;
forward_workspace_size_in_bytes = 0;
} }
void conv:: void conv::
setup( setup(
const tensor& data, const tensor& data,
const tensor& filters, const tensor& filters,
int stride_y, int stride_y_,
int stride_x int stride_x_
) )
{ {
DLIB_CASSERT(data.k() == filters.k(),"");
clear(); clear();
try try
{ {
stride_y = stride_y_;
stride_x = stride_x_;
check(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle)); check(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle));
check(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle, check(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle,
CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT,
...@@ -238,6 +252,33 @@ namespace dlib ...@@ -238,6 +252,33 @@ namespace dlib
&out_k, &out_k,
&out_nr, &out_nr,
&out_nc)); &out_nc));
tensor_descriptor dest_desc;
dest_desc.set_size(out_num_samples,out_nr,out_nc,out_k);
cudnnConvolutionFwdAlgo_t forward_best_algo;
check(cudnnGetConvolutionForwardAlgorithm(
context(),
(const cudnnTensorDescriptor_t)data.get_cudnn_tensor_descriptor().get_handle(),
(const cudnnFilterDescriptor_t)filter_handle,
(const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnTensorDescriptor_t)dest_desc.get_handle(),
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, // or CUDNN_CONVOLUTION_FWD_NO_WORKSPACE,
std::numeric_limits<size_t>::max(),
&forward_best_algo));
forward_algo = forward_best_algo;
check(cudnnGetConvolutionForwardWorkspaceSize(
context(),
(const cudnnTensorDescriptor_t)data.get_cudnn_tensor_descriptor().get_handle(),
(const cudnnFilterDescriptor_t)filter_handle,
(const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnTensorDescriptor_t)dest_desc.get_handle(),
forward_best_algo,
&forward_workspace_size_in_bytes));
CHECK_CUDA(cudaMalloc(&forward_workspace, forward_workspace_size_in_bytes));
} }
catch(...) catch(...)
{ {
...@@ -258,6 +299,31 @@ namespace dlib ...@@ -258,6 +299,31 @@ namespace dlib
const tensor& filters const tensor& filters
) )
{ {
output.set_size(out_num_samples, out_nr, out_nc, out_k);
// TODO, remove
DLIB_CASSERT(output.num_samples() == data.num_samples(),out_num_samples << " " << data.num_samples());
DLIB_CASSERT(output.k() == filters.num_samples(),"");
DLIB_CASSERT(output.nr() == 1+(data.nr()-1)/stride_y,"");
DLIB_CASSERT(output.nc() == 1+(data.nc()-1)/stride_x,"");
const float alpha = 1;
const float beta = 0;
check(cudnnConvolutionForward(
context(),
&alpha,
(const cudnnTensorDescriptor_t)data.get_cudnn_tensor_descriptor().get_handle(),
data.device(),
(const cudnnFilterDescriptor_t)filter_handle,
filters.device(),
(const cudnnConvolutionDescriptor_t)conv_handle,
(cudnnConvolutionFwdAlgo_t)forward_algo,
forward_workspace,
forward_workspace_size_in_bytes,
&beta,
(const cudnnTensorDescriptor_t)output.get_cudnn_tensor_descriptor().get_handle(),
output.device()));
} }
void conv::get_gradient_for_data ( void conv::get_gradient_for_data (
......
...@@ -132,6 +132,10 @@ namespace dlib ...@@ -132,6 +132,10 @@ namespace dlib
int stride_y, int stride_y,
int stride_x int stride_x
); );
/*!
requires
- filters.k() == data.k()
!*/
~conv ( ~conv (
); );
...@@ -196,12 +200,18 @@ namespace dlib ...@@ -196,12 +200,18 @@ namespace dlib
private: private:
void* filter_handle; void* filter_handle;
void* conv_handle; void* conv_handle;
int stride_y;
int stride_x;
// dimensions of the output tensor from operator() // dimensions of the output tensor from operator()
int out_num_samples; int out_num_samples;
int out_nr; int out_nr;
int out_nc; int out_nc;
int out_k; int out_k;
int forward_algo;
size_t forward_workspace_size_in_bytes;
void* forward_workspace;
}; };
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment