Commit 00b2c22c authored by Davis King

Implemented cuDNN based max_pool

parent a07b31da
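For context, here is a hypothetical caller of the reworked interface. It is a sketch only: the include paths and the tensor names (data, gradient_input, data_gradient) are assumed for illustration, not taken from this commit, and the dlib::cuda namespace is assumed from the files below.

    #include <dlib/dnn/cudnn_dlibapi.h>   // header modified below (path assumed)
    #include <dlib/dnn/tensor.h>          // dlib tensor types (path assumed)

    void example(const dlib::tensor& data,
                 const dlib::tensor& gradient_input,
                 dlib::tensor& data_gradient)
    {
        dlib::cuda::max_pool mp;
        mp.setup(2, 2, 2, 2);        // 2x2 window, stride 2 in y and x

        dlib::resizable_tensor out;
        mp(out, data);               // forward: out holds the pooled result

        // backward: with f(data) == dot(gradient_input, out), this adds the
        // gradient of f with respect to data into data_gradient
        mp.get_gradient(gradient_input, out, data, data_gradient);
    }

Note that setup() must be called before the first forward pass, since the default constructor now leaves the cuDNN pooling descriptor unallocated.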
@@ -503,17 +503,48 @@ namespace dlib
     // ------------------------------------------------------------------------------------

         max_pool::max_pool (
-            int window_height,
-            int window_width,
-            int stride_y,
-            int stride_x
-        )
+        ) : handle(nullptr),stride_y(0),stride_x(0)
         {
         }

         max_pool::~max_pool(
         )
         {
+            clear();
+        }
+
+        void max_pool::
+        clear(
+        )
+        {
+            if (handle)
+                cudnnDestroyPoolingDescriptor((cudnnPoolingDescriptor_t)handle);
+            handle = nullptr;
+            stride_y = 0;
+            stride_x = 0;
+        }
+
+        void max_pool::
+        setup(
+            int window_height,
+            int window_width,
+            int stride_y_,
+            int stride_x_
+        )
+        {
+            stride_x = stride_x_;
+            stride_y = stride_y_;
+            cudnnPoolingDescriptor_t poolingDesc;
+            check(cudnnCreatePoolingDescriptor(&poolingDesc));
+            handle = poolingDesc;
+            check(cudnnSetPooling2dDescriptor(poolingDesc,
+                                              CUDNN_POOLING_MAX,
+                                              window_height,
+                                              window_width,
+                                              0,0, // no padding
+                                              stride_y,
+                                              stride_x));
         }

         void max_pool::
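The setup()/clear() pair above wraps the raw cuDNN descriptor lifecycle. A minimal standalone sketch of that lifecycle follows; it assumes cuDNN v5 or later, where cudnnSetPooling2dDescriptor() also takes a cudnnNanPropagation_t argument (the call in this commit targets an older cuDNN that omits it).

    #include <cudnn.h>
    #include <stdexcept>

    // minimal error check in the spirit of dlib's check()
    static void check(cudnnStatus_t s)
    {
        if (s != CUDNN_STATUS_SUCCESS)
            throw std::runtime_error(cudnnGetErrorString(s));
    }

    int main()
    {
        cudnnPoolingDescriptor_t desc;
        check(cudnnCreatePoolingDescriptor(&desc));
        check(cudnnSetPooling2dDescriptor(desc,
                                          CUDNN_POOLING_MAX,
                                          CUDNN_NOT_PROPAGATE_NAN, // cuDNN v5+ only
                                          3, 3,    // window height, width
                                          0, 0,    // vertical, horizontal padding
                                          2, 2));  // vertical, horizontal stride
        cudnnDestroyPoolingDescriptor(desc);
        return 0;
    }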
@@ -522,14 +553,61 @@ namespace dlib
             const tensor& src
         )
         {
+            const float alpha = 1;
+            const float beta = 0;
+
+            int outN;
+            int outC;
+            int outH;
+            int outW;
+            check(cudnnGetPooling2dForwardOutputDim((const cudnnPoolingDescriptor_t)handle,
+                                                    descriptor(src),
+                                                    &outN,
+                                                    &outC,
+                                                    &outH,
+                                                    &outW));
+
+            dest.set_size(outN,outC,outH,outW);
+
+            DLIB_CASSERT(dest.num_samples() == src.num_samples(),"");
+            DLIB_CASSERT(dest.k() == src.k(),"");
+            DLIB_CASSERT(dest.nr() == src.nr()/stride_y,"");
+            DLIB_CASSERT(dest.nc() == src.nc()/stride_x,"");
+
+            check(cudnnPoolingForward(context(),
+                                      (const cudnnPoolingDescriptor_t)handle,
+                                      &alpha,
+                                      descriptor(src),
+                                      src.device(),
+                                      &beta,
+                                      descriptor(dest),
+                                      dest.device()));
         }

         void max_pool::get_gradient(
             const tensor& gradient_input,
+            const tensor& dest,
             const tensor& src,
             tensor& grad
         )
         {
+            DLIB_CASSERT(have_same_dimensions(gradient_input,dest),"");
+            DLIB_CASSERT(have_same_dimensions(src,grad),"");
+
+            const float alpha = 1;
+            const float beta = 0;
+            check(cudnnPoolingBackward(context(),
+                                       (const cudnnPoolingDescriptor_t)handle,
+                                       &alpha,
+                                       descriptor(dest),
+                                       dest.device(),
+                                       descriptor(gradient_input),
+                                       gradient_input.device(),
+                                       descriptor(src),
+                                       src.device(),
+                                       &beta,
+                                       descriptor(grad),
+                                       grad.device()));
         }

     // ------------------------------------------------------------------------------------
...
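A note on the sizes asserted in operator() above: with no padding, cudnnGetPooling2dForwardOutputDim() reports output dimensions of the form 1 + (in - window)/stride (integer division), which reduces to in/stride exactly when the window size equals the stride. A small self-contained check of that arithmetic (the helper name pooled_dim is illustrative):

    #include <cstdio>

    // output size cuDNN reports for one spatial dimension (integer division)
    int pooled_dim(int in, int window, int pad, int stride)
    {
        return 1 + (in + 2*pad - window)/stride;
    }

    int main()
    {
        std::printf("%d\n", pooled_dim(28, 2, 0, 2)); // 14, equals 28/2
        std::printf("%d\n", pooled_dim(28, 3, 0, 2)); // 13, not 28/2 == 14
        return 0;
    }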
@@ -244,45 +244,68 @@ namespace dlib
     class max_pool
     {
     /*!
+        CUDNN_POOLING_MAX
     !*/
     public:
         max_pool(const max_pool&) = delete;
         max_pool& operator=(const max_pool&) = delete;

-        // cudnnCreatePoolingDescriptor(), cudnnSetPooling2dDescriptor()
         max_pool (
+        );
+
+        ~max_pool(
+        );
+
+        void clear(
+        );
+
+        void setup(
             int window_height,
             int window_width,
             int stride_y,
             int stride_x
         );

-        // cudnnDestroyPoolingDescriptor ()
-        ~max_pool(
-        );
-
-        // cudnnGetPooling2dForwardOutputDim(), cudnnPoolingForward()
         void operator() (
             resizable_tensor& dest,
             const tensor& src
         );
         /*!
+            ensures
+                - #dest.num_samples() == src.num_samples()
+                - #dest.k() == src.k()
+                - #dest.nr() == src.nr()/stride_y
+                - #dest.nc() == src.nc()/stride_x
+                - for all valid s, k, r, and c:
+                    - image_plane(#dest,s,k)(r,c) == max(subm_clipped(image_plane(src,s,k),
+                                                                      r*stride_y,
+                                                                      c*stride_x,
+                                                                      window_height,
+                                                                      window_width))
         !*/

-        // cudnnPoolingBackward()
         void get_gradient(
             const tensor& gradient_input,
+            const tensor& dest,
             const tensor& src,
             tensor& grad
         );
         /*!
-            - let OUT be the output of (*this)(OUT,src)
-            - let f(src) == dot(gradient_input,OUT)
-            - Then this function computes the gradient of f() with respect to src and
-              adds it to grad.
+            requires
+                - have_same_dimensions(gradient_input,dest) == true
+                - have_same_dimensions(src,grad) == true
+                - dest contains the result of calling (*this)(dest,src)
+            ensures
+                - Recalling that dest is the output of (*this)(dest,src),
+                  let f(src) == dot(gradient_input,dest)
+                - Then this function computes the gradient of f() with respect to src
+                  and adds it to grad.
         !*/
+
+    private:
+        void* handle;
+        int stride_y;
+        int stride_x;
     };

     // TODO, make the order of parameters of all these functions consistent.
...
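To make the operator() contract above concrete, here is a plain-CPU sketch of the pooling rule it specifies for a single image plane: each output pixel is the max over a window_height x window_width window anchored at (r*stride_y, c*stride_x) and clipped to the image bounds, as subm_clipped() would clip it. This uses plain std::vector rather than dlib tensors and is illustrative only.

    #include <algorithm>
    #include <cfloat>
    #include <vector>

    // reference max pooling over one nr x nc image plane, stored row-major
    std::vector<float> max_pool_plane(
        const std::vector<float>& src, int nr, int nc,
        int window_height, int window_width, int stride_y, int stride_x)
    {
        const int out_nr = nr/stride_y, out_nc = nc/stride_x;
        std::vector<float> dest(out_nr*out_nc, -FLT_MAX);
        for (int r = 0; r < out_nr; ++r)
        for (int c = 0; c < out_nc; ++c)
        {
            // clip the window to the image, as subm_clipped() would
            const int rtop = r*stride_y, cleft = c*stride_x;
            for (int y = rtop; y < std::min(rtop+window_height, nr); ++y)
            for (int x = cleft; x < std::min(cleft+window_width, nc); ++x)
                dest[r*out_nc+c] = std::max(dest[r*out_nc+c], src[y*nc+x]);
        }
        return dest;
    }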