Commit 6a1e50a4 authored by Davis King

setup cuDNN forward convolution

parent 10480cb9
......@@ -182,7 +182,10 @@ namespace dlib
out_num_samples(0),
out_k(0),
out_nr(0),
out_nc(0)
out_nc(0),
forward_algo(0),
forward_workspace_size_in_bytes(0),
forward_workspace(nullptr)
{
}
......@@ -200,19 +203,30 @@ namespace dlib
out_k = 0;
out_nr = 0;
out_nc = 0;
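// Free any workspace left over from a previous setup() call.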
if (forward_workspace)
cudaFree(forward_workspace);
forward_workspace = nullptr;
forward_algo = 0;
forward_workspace_size_in_bytes = 0;
}
void conv::
setup(
const tensor& data,
const tensor& filters,
int stride_y,
int stride_x
int stride_y_,
int stride_x_
)
{
DLIB_CASSERT(data.k() == filters.k(),"");
clear();
try
{
stride_y = stride_y_;
stride_x = stride_x_;
check(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle));
check(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle,
CUDNN_DATA_FLOAT,
......@@ -238,6 +252,33 @@ namespace dlib
&out_k,
&out_nr,
&out_nc));
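// Describe the output tensor so cuDNN can choose a forward algorithm for it.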
tensor_descriptor dest_desc;
dest_desc.set_size(out_num_samples,out_nr,out_nc,out_k);
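// Ask cuDNN for the fastest forward algorithm given these tensor shapes.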
cudnnConvolutionFwdAlgo_t forward_best_algo;
check(cudnnGetConvolutionForwardAlgorithm(
context(),
(const cudnnTensorDescriptor_t)data.get_cudnn_tensor_descriptor().get_handle(),
(const cudnnFilterDescriptor_t)filter_handle,
(const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnTensorDescriptor_t)dest_desc.get_handle(),
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, // or CUDNN_CONVOLUTION_FWD_NO_WORKSPACE,
std::numeric_limits<size_t>::max(),
&forward_best_algo));
forward_algo = forward_best_algo;
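// Find out how much scratch memory the chosen algorithm needs and allocate it
// now so operator() doesn't have to.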
check(cudnnGetConvolutionForwardWorkspaceSize(
context(),
(const cudnnTensorDescriptor_t)data.get_cudnn_tensor_descriptor().get_handle(),
(const cudnnFilterDescriptor_t)filter_handle,
(const cudnnConvolutionDescriptor_t)conv_handle,
(const cudnnTensorDescriptor_t)dest_desc.get_handle(),
forward_best_algo,
&forward_workspace_size_in_bytes));
CHECK_CUDA(cudaMalloc(&forward_workspace, forward_workspace_size_in_bytes));
}
catch(...)
{
......@@ -258,6 +299,31 @@ namespace dlib
const tensor& filters
)
{
output.set_size(out_num_samples, out_nr, out_nc, out_k);
// TODO, remove
DLIB_CASSERT(output.num_samples() == data.num_samples(),out_num_samples << " " << data.num_samples());
DLIB_CASSERT(output.k() == filters.num_samples(),"");
DLIB_CASSERT(output.nr() == 1+(data.nr()-1)/stride_y,"");
DLIB_CASSERT(output.nc() == 1+(data.nc()-1)/stride_x,"");
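// alpha and beta scale the new result and the existing output contents
// respectively; beta == 0 means output is simply overwritten.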
const float alpha = 1;
const float beta = 0;
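// Run the forward convolution with the algorithm and workspace picked in setup().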
check(cudnnConvolutionForward(
context(),
&alpha,
(const cudnnTensorDescriptor_t)data.get_cudnn_tensor_descriptor().get_handle(),
data.device(),
(const cudnnFilterDescriptor_t)filter_handle,
filters.device(),
(const cudnnConvolutionDescriptor_t)conv_handle,
(cudnnConvolutionFwdAlgo_t)forward_algo,
forward_workspace,
forward_workspace_size_in_bytes,
&beta,
(const cudnnTensorDescriptor_t)output.get_cudnn_tensor_descriptor().get_handle(),
output.device()));
}
void conv::get_gradient_for_data (
......
......@@ -132,6 +132,10 @@ namespace dlib
int stride_y,
int stride_x
);
/*!
requires
- filters.k() == data.k()
!*/
~conv (
);
......@@ -196,12 +200,18 @@ namespace dlib
private:
void* filter_handle;
void* conv_handle;
int stride_y;
int stride_x;
// dimensions of the output tensor from operator()
int out_num_samples;
int out_nr;
int out_nc;
int out_k;
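// forward algorithm and scratch space chosen by setup() for use in operator()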
int forward_algo;
size_t forward_workspace_size_in_bytes;
void* forward_workspace;
};
// ------------------------------------------------------------------------------------
......
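// A minimal usage sketch (an illustration only, not part of this commit): it assumes
// dlib's resizable_tensor with the set_size(num_samples, nr, nc, k) ordering used
// above, and that conv::operator() takes (output, data, filters).
//
//     resizable_tensor data, filters, output;
//     data.set_size(10, 28, 28, 3);    // 10 samples, 28x28 pixels, 3 channels
//     filters.set_size(16, 5, 5, 3);   // 16 filters, 5x5, k matches data.k()
//     conv c;
//     c.setup(data, filters, 1, 1);    // picks the forward algorithm and workspace
//     c(output, data, filters);        // runs cudnnConvolutionForward()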