Commit 7ae43ae2 authored by Davis King's avatar Davis King

Fixed some resource leaks. Also fixed max_pool so it does exactly what the

spec says it should.
parent cbd57be6
...@@ -480,6 +480,7 @@ namespace dlib ...@@ -480,6 +480,7 @@ namespace dlib
catch(...) catch(...)
{ {
clear(); clear();
throw;
} }
} }
...@@ -581,7 +582,7 @@ namespace dlib ...@@ -581,7 +582,7 @@ namespace dlib
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
max_pool::max_pool ( max_pool::max_pool (
) : handle(nullptr),stride_y(0),stride_x(0) ) : handle(nullptr),window_height(0),window_width(0),stride_y(0),stride_x(0)
{ {
} }
...@@ -598,31 +599,52 @@ namespace dlib ...@@ -598,31 +599,52 @@ namespace dlib
if (handle) if (handle)
cudnnDestroyPoolingDescriptor((cudnnPoolingDescriptor_t)handle); cudnnDestroyPoolingDescriptor((cudnnPoolingDescriptor_t)handle);
handle = nullptr; handle = nullptr;
window_height = 0;
window_width = 0;
stride_y = 0; stride_y = 0;
stride_x = 0; stride_x = 0;
} }
void max_pool:: void max_pool::
setup( setup(
int window_height, int window_height_,
int window_width, int window_width_,
int stride_y_, int stride_y_,
int stride_x_ int stride_x_
) )
{ {
stride_x = stride_x_; if (window_height == window_height_ &&
stride_y = stride_y_; window_width == window_width_ &&
cudnnPoolingDescriptor_t poolingDesc; stride_y == stride_y_ &&
CHECK_CUDNN(cudnnCreatePoolingDescriptor(&poolingDesc)); stride_x == stride_x_ )
handle = poolingDesc; {
return;
}
CHECK_CUDNN(cudnnSetPooling2dDescriptor(poolingDesc, clear();
CUDNN_POOLING_MAX, try
window_height, {
window_width, window_height = window_height_;
0,0, // no padding window_width = window_width_;
stride_y, stride_x = stride_x_;
stride_x)); stride_y = stride_y_;
cudnnPoolingDescriptor_t poolingDesc;
CHECK_CUDNN(cudnnCreatePoolingDescriptor(&poolingDesc));
handle = poolingDesc;
CHECK_CUDNN(cudnnSetPooling2dDescriptor(poolingDesc,
CUDNN_POOLING_MAX,
window_height,
window_width,
0,0, // no padding
stride_y,
stride_x));
}
catch(...)
{
clear();
throw;
}
} }
void max_pool:: void max_pool::
...@@ -649,8 +671,8 @@ namespace dlib ...@@ -649,8 +671,8 @@ namespace dlib
DLIB_CASSERT(dest.num_samples() == src.num_samples(),""); DLIB_CASSERT(dest.num_samples() == src.num_samples(),"");
DLIB_CASSERT(dest.k() == src.k(),""); DLIB_CASSERT(dest.k() == src.k(),"");
DLIB_CASSERT(dest.nr() == src.nr()/stride_y,""); DLIB_CASSERT(dest.nr() == src.nr()/stride_y, stride_y << ", " << dest.nr() << " " << src.nr()/stride_y);
DLIB_CASSERT(dest.nc() == src.nc()/stride_x,""); DLIB_CASSERT(dest.nc() == src.nc()/stride_x, stride_x << ", " << dest.nc() << " " << src.nc()/stride_x);
CHECK_CUDNN(cudnnPoolingForward(context(), CHECK_CUDNN(cudnnPoolingForward(context(),
(const cudnnPoolingDescriptor_t)handle, (const cudnnPoolingDescriptor_t)handle,
...@@ -673,7 +695,7 @@ namespace dlib ...@@ -673,7 +695,7 @@ namespace dlib
DLIB_CASSERT(have_same_dimensions(src,grad),""); DLIB_CASSERT(have_same_dimensions(src,grad),"");
const float alpha = 1; const float alpha = 1;
const float beta = 0; const float beta = 1;
CHECK_CUDNN(cudnnPoolingBackward(context(), CHECK_CUDNN(cudnnPoolingBackward(context(),
(const cudnnPoolingDescriptor_t)handle, (const cudnnPoolingDescriptor_t)handle,
&alpha, &alpha,
......
...@@ -328,6 +328,8 @@ namespace dlib ...@@ -328,6 +328,8 @@ namespace dlib
private: private:
void* handle; void* handle;
int window_height;
int window_width;
int stride_y; int stride_y;
int stride_x; int stride_x;
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment