Commit 7ae43ae2 authored by Davis King's avatar Davis King

Fixed some resource leaks. Also fixed max_pool so it does exactly what the

spec says it should.
parent cbd57be6
......@@ -480,6 +480,7 @@ namespace dlib
catch(...)
{
clear();
throw;
}
}
......@@ -581,7 +582,7 @@ namespace dlib
// ------------------------------------------------------------------------------------
max_pool::max_pool (
) : handle(nullptr),stride_y(0),stride_x(0)
) : handle(nullptr),window_height(0),window_width(0),stride_y(0),stride_x(0)
{
}
......@@ -598,18 +599,33 @@ namespace dlib
if (handle)
cudnnDestroyPoolingDescriptor((cudnnPoolingDescriptor_t)handle);
handle = nullptr;
window_height = 0;
window_width = 0;
stride_y = 0;
stride_x = 0;
}
void max_pool::
setup(
int window_height,
int window_width,
int window_height_,
int window_width_,
int stride_y_,
int stride_x_
)
{
if (window_height == window_height_ &&
window_width == window_width_ &&
stride_y == stride_y_ &&
stride_x == stride_x_ )
{
return;
}
clear();
try
{
window_height = window_height_;
window_width = window_width_;
stride_x = stride_x_;
stride_y = stride_y_;
cudnnPoolingDescriptor_t poolingDesc;
......@@ -624,6 +640,12 @@ namespace dlib
stride_y,
stride_x));
}
catch(...)
{
clear();
throw;
}
}
void max_pool::
operator() (
......@@ -649,8 +671,8 @@ namespace dlib
DLIB_CASSERT(dest.num_samples() == src.num_samples(),"");
DLIB_CASSERT(dest.k() == src.k(),"");
DLIB_CASSERT(dest.nr() == src.nr()/stride_y,"");
DLIB_CASSERT(dest.nc() == src.nc()/stride_x,"");
DLIB_CASSERT(dest.nr() == src.nr()/stride_y, stride_y << ", " << dest.nr() << " " << src.nr()/stride_y);
DLIB_CASSERT(dest.nc() == src.nc()/stride_x, stride_x << ", " << dest.nc() << " " << src.nc()/stride_x);
CHECK_CUDNN(cudnnPoolingForward(context(),
(const cudnnPoolingDescriptor_t)handle,
......@@ -673,7 +695,7 @@ namespace dlib
DLIB_CASSERT(have_same_dimensions(src,grad),"");
const float alpha = 1;
const float beta = 0;
const float beta = 1;
CHECK_CUDNN(cudnnPoolingBackward(context(),
(const cudnnPoolingDescriptor_t)handle,
&alpha,
......
......@@ -328,6 +328,8 @@ namespace dlib
private:
void* handle;
int window_height;
int window_width;
int stride_y;
int stride_x;
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment