Commit d63c4682 authored by Davis King

Added some of the cuDNN conv calls.

parent 1022b5b9
......@@ -9,6 +9,7 @@
#include "tensor.h"
#include <cudnn.h>
#include <iostream>
#include <string>
#include "cuda_utils.h"
......@@ -31,10 +32,11 @@ namespace dlib
case CUDNN_STATUS_BAD_PARAM:
throw cudnn_error("CUDNN_STATUS_BAD_PARAM");
default:
throw cudnn_error("A call to cuDNN failed.");
throw cudnn_error("A call to cuDNN failed: " + std::string(cudnnGetErrorString(s)));
}
}
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
cudnn_context::cudnn_context() : handle(nullptr)
......@@ -162,14 +164,83 @@ namespace dlib
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
conv::conv(
conv::
conv(
) :
filter_handle(nullptr),
conv_handle(nullptr),
out_num_samples(0),
out_k(0),
out_nr(0),
out_nc(0)
{
}
void conv::
clear (
)
{
if (filter_handle)
cudnnDestroyFilterDescriptor((cudnnFilterDescriptor_t)filter_handle);
if (conv_handle)
cudnnDestroyConvolutionDescriptor((cudnnConvolutionDescriptor_t)conv_handle);
filter_handle = nullptr;
conv_handle = nullptr;
out_num_samples = 0;
out_k = 0;
out_nr = 0;
out_nc = 0;
}
// Prepares this object to convolve `data` with `filters`.
//
// Any previously held descriptors are released first.  A cuDNN filter
// descriptor is then created from the dimensions of `filters`
// (num_samples, k, nr, nc), a 2D convolution descriptor is created using
// padding of filters.nr()/2 rows and filters.nc()/2 columns (integer
// division) with the given strides, and finally cuDNN is asked for the
// dimensions of the output tensor, which are cached in out_num_samples,
// out_k, out_nr and out_nc.
//
// Parameters:
//   context  - cuDNN context.  NOTE(review): not used by this function as
//              written; presumably kept for the interface -- confirm.
//   data     - input tensor; only its cuDNN tensor descriptor is read here.
//   filters  - filter tensor; only its dimensions are read here.
//   stride_y - vertical stride passed to cuDNN.
//   stride_x - horizontal stride passed to cuDNN.
//
// Throws:
//   Whatever check() throws when a cuDNN call fails.  In that case every
//   partially created descriptor is released via clear() before the
//   exception is propagated, so the object is left in the default
//   (not set up) state.
void conv::
setup(
    cudnn_context& context,
    const tensor& data,
    const tensor& filters,
    int stride_y,
    int stride_x
)
{
    // Drop any state left over from a previous call to setup().
    clear();
    try
    {
        check(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle));
        check(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle,
                                         CUDNN_DATA_FLOAT,
                                         filters.num_samples(),
                                         filters.k(),
                                         filters.nr(),
                                         filters.nc()));

        check(cudnnCreateConvolutionDescriptor((cudnnConvolutionDescriptor_t*)&conv_handle));
        check(cudnnSetConvolution2dDescriptor((cudnnConvolutionDescriptor_t)conv_handle,
                                              filters.nr()/2, // vertical padding
                                              filters.nc()/2, // horizontal padding
                                              stride_y,
                                              stride_x,
                                              1, 1, // must be 1,1
                                              CUDNN_CONVOLUTION)); // could also be CUDNN_CROSS_CORRELATION

        // Ask cuDNN what shape the forward convolution output will have and
        // cache it in the out_* members.
        check(cudnnGetConvolution2dForwardOutputDim(
                (const cudnnConvolutionDescriptor_t)conv_handle,
                (const cudnnTensorDescriptor_t)data.get_cudnn_tensor_descriptor().get_handle(),
                (const cudnnFilterDescriptor_t)filter_handle,
                &out_num_samples,
                &out_k,
                &out_nr,
                &out_nc));
    }
    catch(...)
    {
        // Release anything that was partially created, then let the caller
        // know that setup failed.  Swallowing the exception here would leave
        // a silently unconfigured object that the caller believes is ready.
        clear();
        throw;
    }
}
// Destructor.  Hands off to clear(), which releases any cuDNN descriptors
// this object still holds.
conv::~conv()
{
    clear();
}
void conv::operator() (
......
......@@ -66,6 +66,10 @@ namespace dlib
int nc,
int k
);
/*!
ensures
- if any of the arguments are 0 then they are all set to 0 in the tensor.
!*/
void get_size (
int& n,
......@@ -145,7 +149,12 @@ namespace dlib
conv(const conv&) = delete;
conv& operator=(const conv&) = delete;
conv(
conv();
/*!
    ensures
        - This object holds no cuDNN filter or convolution descriptor (both
          internal handles are null) and all cached output dimensions are 0.
!*/
void clear(
);
/*!
    ensures
        - Destroys the cuDNN filter and convolution descriptors held by this
          object, if any.
        - Resets the internal handles to null and the cached output
          dimensions to 0, i.e. the same state a default constructed conv
          has.
!*/
void setup(
cudnn_context& context,
const tensor& data,
const tensor& filters,
......@@ -153,6 +162,9 @@ namespace dlib
int stride_x
);
~conv (
);
/*!
    ensures
        - Calls clear(), releasing any cuDNN descriptors held by this object.
!*/
void operator() (
resizable_tensor& output,
const tensor& data,
......@@ -210,6 +222,15 @@ namespace dlib
and adds this gradient to filters_gradient.
!*/
private:
void* filter_handle;
void* conv_handle;
// dimensions of the output tensor from operator()
int out_num_samples;
int out_nr;
int out_nc;
int out_k;
};
// ------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment