Commit 10480cb9 authored by Davis King

Hid the cuDNN context in a thread local variable so the user doesn't
need to deal with it.
parent d63c4682
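
This commit replaces the public, user-visible cudnn_context object with a private RAII class kept in a thread_local variable: the first cuDNN call on a thread lazily creates that thread's handle, and the handle is destroyed automatically when the thread exits. Below is a minimal standalone sketch of the same pattern, with illustrative names rather than dlib's actual ones:

    #include <cudnn.h>
    #include <stdexcept>

    namespace
    {
        // RAII owner of one cuDNN handle.  Non-copyable, so each
        // thread_local instance uniquely owns its handle.
        class cudnn_handle_owner
        {
        public:
            cudnn_handle_owner(const cudnn_handle_owner&) = delete;
            cudnn_handle_owner& operator=(const cudnn_handle_owner&) = delete;

            cudnn_handle_owner()
            {
                if (cudnnCreate(&handle) != CUDNN_STATUS_SUCCESS)
                    throw std::runtime_error("cudnnCreate() failed");
            }
            ~cudnn_handle_owner() { cudnnDestroy(handle); }

            cudnnHandle_t get() const { return handle; }

        private:
            cudnnHandle_t handle;
        };

        // One handle per thread, created on first use and destroyed
        // automatically at thread exit.
        cudnnHandle_t context()
        {
            thread_local cudnn_handle_owner owner;
            return owner.get();
        }
    }

A per-thread handle means concurrent threads never share a handle, which sidesteps any need for locking around it.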
@@ -36,23 +36,36 @@ namespace dlib
             }
         }

     // ------------------------------------------------------------------------------------

-        cudnn_context::cudnn_context() : handle(nullptr)
-        {
-            cudnnHandle_t h;
-            check(cudnnCreate(&h));
-            handle = h;
-        }
-
-        cudnn_context::~cudnn_context()
-        {
-            if (handle)
-            {
-                cudnnDestroy((cudnnHandle_t)handle);
-                handle = nullptr;
-            }
-        }
+        class cudnn_context
+        {
+        public:
+            // not copyable
+            cudnn_context(const cudnn_context&) = delete;
+            cudnn_context& operator=(const cudnn_context&) = delete;
+
+            cudnn_context()
+            {
+                check(cudnnCreate(&handle));
+            }
+
+            ~cudnn_context()
+            {
+                cudnnDestroy(handle);
+            }
+
+            cudnnHandle_t get_handle (
+            ) const { return handle; }
+
+        private:
+            cudnnHandle_t handle;
+        };
+
+        static cudnnHandle_t context()
+        {
+            thread_local cudnn_context c;
+            return c.get_handle();
+        }

     // ------------------------------------------------------------------------------------
@@ -136,7 +149,6 @@ namespace dlib
     // ------------------------------------------------------------------------------------

         void add(
-            cudnn_context& context,
             float beta,
             tensor& dest,
             float alpha,
@@ -146,7 +158,6 @@ namespace dlib
         }

         void set_tensor (
-            cudnn_context& context,
             tensor& t,
             float value
         )
@@ -154,7 +165,6 @@ namespace dlib
         }

         void scale_tensor (
-            cudnn_context& context,
             tensor& t,
             float value
         )
@@ -194,7 +204,6 @@ namespace dlib

         void conv::
         setup(
-            cudnn_context& context,
             const tensor& data,
             const tensor& filters,
             int stride_y,
@@ -272,7 +281,6 @@ namespace dlib
     // ------------------------------------------------------------------------------------

         void soft_max (
-            cudnn_context& context,
             resizable_tensor& dest,
             const tensor& src
         )
@@ -280,7 +288,6 @@ namespace dlib
         }

         void soft_max_gradient (
-            cudnn_context& context,
             tensor& grad,
             const tensor& src,
             const tensor& gradient_input
@@ -292,7 +299,6 @@ namespace dlib
     // ------------------------------------------------------------------------------------

         max_pool::max_pool (
-            cudnn_context& context,
             int window_height,
             int window_width,
             int stride_y,
@@ -326,7 +332,6 @@ namespace dlib
     // ------------------------------------------------------------------------------------

         void sigmoid (
-            cudnn_context& context,
             resizable_tensor& dest,
             const tensor& src
         )
@@ -334,7 +339,6 @@ namespace dlib
         }

         void sigmoid_gradient (
-            cudnn_context& context,
             tensor& grad,
             const tensor& src,
             const tensor& gradient_input
@@ -345,7 +349,6 @@ namespace dlib
     // ------------------------------------------------------------------------------------

         void relu (
-            cudnn_context& context,
             resizable_tensor& dest,
             const tensor& src
         )
@@ -353,7 +356,6 @@ namespace dlib
         }

         void relu_gradient (
-            cudnn_context& context,
             tensor& grad,
             const tensor& src,
             const tensor& gradient_input
@@ -364,7 +366,6 @@ namespace dlib
     // ------------------------------------------------------------------------------------

         void tanh (
-            cudnn_context& context,
             resizable_tensor& dest,
             const tensor& src
         )
@@ -372,7 +373,6 @@ namespace dlib
         }

         void tanh_gradient (
-            cudnn_context& context,
             tensor& grad,
             const tensor& src,
             const tensor& gradient_input
...
@@ -17,31 +17,6 @@ namespace dlib

     // -----------------------------------------------------------------------------------

-        class cudnn_context
-        {
-        public:
-            // not copyable
-            cudnn_context(const cudnn_context&) = delete;
-            cudnn_context& operator=(const cudnn_context&) = delete;
-
-            // but is movable
-            cudnn_context(cudnn_context&& item) : cudnn_context() { swap(item); }
-            cudnn_context& operator=(cudnn_context&& item) { swap(item); return *this; }
-
-            cudnn_context();
-            ~cudnn_context();
-
-            const void* get_handle (
-            ) const { return handle; }
-
-        private:
-
-            void swap(cudnn_context& item) { std::swap(handle, item.handle); }
-
-            void* handle;
-        };
-
-    // ------------------------------------------------------------------------------------
-
         class tensor_descriptor
         {
             /*!
@@ -91,7 +66,6 @@ namespace dlib
     // ------------------------------------------------------------------------------------

         void add(
-            cudnn_context& context,
             float beta,
             tensor& dest,
             float alpha,
@@ -117,7 +91,6 @@ namespace dlib
         !*/

         void set_tensor (
-            cudnn_context& context,
             tensor& t,
             float value
         );
@@ -128,7 +101,6 @@ namespace dlib
         !*/

         void scale_tensor (
-            cudnn_context& context,
             tensor& t,
             float value
         );
@@ -155,7 +127,6 @@ namespace dlib
             );

             void setup(
-                cudnn_context& context,
                 const tensor& data,
                 const tensor& filters,
                 int stride_y,
@@ -236,7 +207,6 @@ namespace dlib
     // ------------------------------------------------------------------------------------

         void soft_max (
-            cudnn_context& context,
             resizable_tensor& dest,
             const tensor& src
         );
@@ -245,13 +215,12 @@ namespace dlib
         !*/

         void soft_max_gradient (
-            cudnn_context& context,
             tensor& grad,
             const tensor& src,
             const tensor& gradient_input
         );
         /*!
-            - let OUT be the output of soft_max(context,OUT,src)
+            - let OUT be the output of soft_max(OUT,src)
             - let f(src) == dot(gradient_input,OUT)
             - Then this function computes the gradient of f() with respect to src
               and adds it to grad.
@@ -271,7 +240,6 @@ namespace dlib

             // cudnnCreatePoolingDescriptor(), cudnnSetPooling2dDescriptor()
             max_pool (
-                cudnn_context& context,
                 int window_height,
                 int window_width,
                 int stride_y,
@@ -310,7 +278,6 @@ namespace dlib

         // cudnnActivationForward(), CUDNN_ACTIVATION_SIGMOID
         void sigmoid (
-            cudnn_context& context,
             resizable_tensor& dest,
             const tensor& src
         );
@@ -323,7 +290,6 @@ namespace dlib

         // cudnnActivationBackward()
         void sigmoid_gradient (
-            cudnn_context& context,
             tensor& grad,
             const tensor& src,
             const tensor& gradient_input
@@ -333,7 +299,7 @@ namespace dlib
             - have_same_dimensions(src,gradient_input) == true
             - have_same_dimensions(src,grad) == true
         ensures
-            - let OUT be the output of sigmoid(context,OUT,src)
+            - let OUT be the output of sigmoid(OUT,src)
             - let f(src) == dot(gradient_input,OUT)
             - Then this function computes the gradient of f() with respect to src and
               adds it to grad.
@@ -343,7 +309,6 @@ namespace dlib

         // cudnnActivationForward(), CUDNN_ACTIVATION_RELU
         void relu (
-            cudnn_context& context,
             resizable_tensor& dest,
             const tensor& src
         );
@@ -356,7 +321,6 @@ namespace dlib

         // cudnnActivationBackward()
         void relu_gradient (
-            cudnn_context& context,
             tensor& grad,
             const tensor& src,
             const tensor& gradient_input
@@ -366,7 +330,7 @@ namespace dlib
             - have_same_dimensions(src,gradient_input) == true
             - have_same_dimensions(src,grad) == true
         ensures
-            - let OUT be the output of relu(context,OUT,src)
+            - let OUT be the output of relu(OUT,src)
             - let f(src) == dot(gradient_input,OUT)
             - Then this function computes the gradient of f() with respect to src and
               adds it to grad.
@@ -376,7 +340,6 @@ namespace dlib

         // cudnnActivationForward(), CUDNN_ACTIVATION_TANH
         void tanh (
-            cudnn_context& context,
             resizable_tensor& dest,
             const tensor& src
         );
@@ -389,7 +352,6 @@ namespace dlib

         // cudnnActivationBackward()
         void tanh_gradient (
-            cudnn_context& context,
             tensor& grad,
             const tensor& src,
             const tensor& gradient_input
@@ -399,7 +361,7 @@ namespace dlib
             - have_same_dimensions(src,gradient_input) == true
             - have_same_dimensions(src,grad) == true
         ensures
-            - let OUT be the output of tanh(context,OUT,src)
+            - let OUT be the output of tanh(OUT,src)
            - let f(src) == dot(gradient_input,OUT)
             - Then this function computes the gradient of f() with respect to src and
               adds it to grad.
...
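
The user-facing effect, visible throughout both files above, is that every wrapper function loses its leading cudnn_context& parameter; each implementation now fetches the calling thread's handle internally via context(). A schematic before/after of a call site (illustrative fragment, not taken from dlib's own code):

    // Before this commit a caller constructed a context and threaded it
    // through every call:
    //     cudnn_context ctx;
    //     soft_max(ctx, dest, src);
    //
    // After, the handle is implicit and per-thread:
    soft_max(dest, src);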