Commit 10480cb9 authored by Davis King

Hid the cuDNN context in a thread local variable so the user doesn't need to deal with it.
parent d63c4682
@@ -36,23 +36,36 @@ namespace dlib
            }
        }

    // ------------------------------------------------------------------------------------
    // ------------------------------------------------------------------------------------

-        cudnn_context::cudnn_context() : handle(nullptr)
-        {
-            cudnnHandle_t h;
-            check(cudnnCreate(&h));
-            handle = h;
-        }
-
-        cudnn_context::~cudnn_context()
-        {
-            if (handle)
-            {
-                cudnnDestroy((cudnnHandle_t)handle);
-                handle = nullptr;
-            }
-        }
+        class cudnn_context
+        {
+        public:
+            // not copyable
+            cudnn_context(const cudnn_context&) = delete;
+            cudnn_context& operator=(const cudnn_context&) = delete;
+
+            cudnn_context()
+            {
+                check(cudnnCreate(&handle));
+            }
+
+            ~cudnn_context()
+            {
+                cudnnDestroy(handle);
+            }
+
+            cudnnHandle_t get_handle (
+            ) const { return handle; }
+
+        private:
+            cudnnHandle_t handle;
+        };
+
+        static cudnnHandle_t context()
+        {
+            thread_local cudnn_context c;
+            return c.get_handle();
+        }

    // ------------------------------------------------------------------------------------
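The net effect of the block above: each thread that calls into this file lazily creates exactly one cuDNN handle on first use, and the handle is torn down automatically when the thread exits, so no cudnn_context object ever appears in the public API. A minimal sketch of how a wrapper in this file can now pick up the handle (the body below is illustrative; descriptor(t) and t.device() are assumed stand-ins for the tensor's cudnnTensorDescriptor_t and device memory, not actual dlib accessors):

        // Hypothetical wrapper: no cudnn_context parameter, the per-thread
        // handle comes from context().
        void set_tensor (
            tensor& t,
            float value
        )
        {
            if (t.size() == 0)
                return;
            // cudnnSetTensor(handle, yDesc, y, valuePtr) fills y with value.
            check(cudnnSetTensor(context(),
                                 descriptor(t),
                                 t.device(),
                                 &value));
        }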
@@ -136,7 +149,6 @@ namespace dlib
    // ------------------------------------------------------------------------------------

        void add(
-            cudnn_context& context,
            float beta,
            tensor& dest,
            float alpha,
@@ -146,7 +158,6 @@ namespace dlib
        }

        void set_tensor (
-            cudnn_context& context,
            tensor& t,
            float value
        )
@@ -154,7 +165,6 @@ namespace dlib
        }

        void scale_tensor (
-            cudnn_context& context,
            tensor& t,
            float value
        )
@@ -194,7 +204,6 @@ namespace dlib

        void conv::
        setup(
-            cudnn_context& context,
            const tensor& data,
            const tensor& filters,
            int stride_y,
@@ -272,7 +281,6 @@ namespace dlib
    // ------------------------------------------------------------------------------------

        void soft_max (
-            cudnn_context& context,
            resizable_tensor& dest,
            const tensor& src
        )
@@ -280,7 +288,6 @@ namespace dlib
        }

        void soft_max_gradient (
-            cudnn_context& context,
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input
@@ -292,7 +299,6 @@ namespace dlib
    // ------------------------------------------------------------------------------------

        max_pool::max_pool (
-            cudnn_context& context,
            int window_height,
            int window_width,
            int stride_y,
@@ -326,7 +332,6 @@ namespace dlib
    // ------------------------------------------------------------------------------------

        void sigmoid (
-            cudnn_context& context,
            resizable_tensor& dest,
            const tensor& src
        )
@@ -334,7 +339,6 @@ namespace dlib
        }

        void sigmoid_gradient (
-            cudnn_context& context,
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input
@@ -345,7 +349,6 @@ namespace dlib
    // ------------------------------------------------------------------------------------

        void relu (
-            cudnn_context& context,
            resizable_tensor& dest,
            const tensor& src
        )
@@ -353,7 +356,6 @@ namespace dlib
        }

        void relu_gradient (
-            cudnn_context& context,
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input
@@ -364,7 +366,6 @@ namespace dlib
    // ------------------------------------------------------------------------------------

        void tanh (
-            cudnn_context& context,
            resizable_tensor& dest,
            const tensor& src
        )
@@ -372,7 +373,6 @@ namespace dlib
        }

        void tanh_gradient (
-            cudnn_context& context,
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input
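Taken together, the changes in this file remove the cudnn_context parameter from every wrapper, so callers no longer construct or pass a context at all. A hypothetical before/after of a call site (surrounding code illustrative):

        // before this commit:
        //     cudnn_context ctx;
        //     soft_max(ctx, dest, src);
        //     relu(ctx, dest2, src2);
        //
        // after this commit, the per-thread handle is implicit:
        //     soft_max(dest, src);
        //     relu(dest2, src2);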
@@ -17,31 +17,6 @@ namespace dlib

    // -----------------------------------------------------------------------------------

-        class cudnn_context
-        {
-        public:
-            // not copyable
-            cudnn_context(const cudnn_context&) = delete;
-            cudnn_context& operator=(const cudnn_context&) = delete;
-
-            // but is movable
-            cudnn_context(cudnn_context&& item) : cudnn_context() { swap(item); }
-            cudnn_context& operator=(cudnn_context&& item) { swap(item); return *this; }
-
-            cudnn_context();
-            ~cudnn_context();
-
-            const void* get_handle (
-            ) const { return handle; }
-
-        private:
-
-            void swap(cudnn_context& item) { std::swap(handle, item.handle); }
-
-            void* handle;
-        };
-
-    // ------------------------------------------------------------------------------------

        class tensor_descriptor
        {
        /*!
@@ -91,7 +66,6 @@ namespace dlib
    // ------------------------------------------------------------------------------------

        void add(
-            cudnn_context& context,
            float beta,
            tensor& dest,
            float alpha,
@@ -117,7 +91,6 @@ namespace dlib
        !*/

        void set_tensor (
-            cudnn_context& context,
            tensor& t,
            float value
        );
@@ -128,7 +101,6 @@ namespace dlib
        !*/

        void scale_tensor (
-            cudnn_context& context,
            tensor& t,
            float value
        );
@@ -155,7 +127,6 @@ namespace dlib
            );

            void setup(
-                cudnn_context& context,
                const tensor& data,
                const tensor& filters,
                int stride_y,
@@ -236,7 +207,6 @@ namespace dlib
    // ------------------------------------------------------------------------------------

        void soft_max (
-            cudnn_context& context,
            resizable_tensor& dest,
            const tensor& src
        );
@@ -245,13 +215,12 @@ namespace dlib
        !*/

        void soft_max_gradient (
-            cudnn_context& context,
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input
        );
        /*!
-            - let OUT be the output of soft_max(context,OUT,src)
+            - let OUT be the output of soft_max(OUT,src)
            - let f(src) == dot(gradient_input,OUT)
            - Then this function computes the gradient of f() with respect to src
              and adds it to grad.
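Spelled out, the clause above is the standard softmax backward rule: with OUT = softmax(src) and f(src) = dot(gradient_input, OUT), the derivative is df/dsrc[i] = OUT[i]*(gradient_input[i] - dot(gradient_input, OUT)). A plain CPU sketch of that formula over a single vector of values (illustrative reference only, not the tensor layout or cuDNN path the real function uses):

        #include <algorithm>
        #include <cmath>
        #include <cstddef>
        #include <vector>

        // Accumulates OUT[i]*(gradient_input[i] - dot(gradient_input, OUT))
        // into grad[i], where OUT = softmax(src), matching the
        // "adds it to grad" clause above.
        void soft_max_gradient_ref (
            std::vector<float>& grad,
            const std::vector<float>& src,
            const std::vector<float>& gradient_input
        )
        {
            const std::size_t n = src.size();
            std::vector<float> out(n);

            // forward pass: numerically stable softmax
            const float m = *std::max_element(src.begin(), src.end());
            float sum = 0;
            for (std::size_t i = 0; i < n; ++i)
            {
                out[i] = std::exp(src[i] - m);
                sum += out[i];
            }
            for (std::size_t i = 0; i < n; ++i)
                out[i] /= sum;

            // dot(gradient_input, OUT)
            float d = 0;
            for (std::size_t i = 0; i < n; ++i)
                d += gradient_input[i]*out[i];

            // accumulate the gradient of f(src) = dot(gradient_input, OUT)
            for (std::size_t i = 0; i < n; ++i)
                grad[i] += out[i]*(gradient_input[i] - d);
        }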
@@ -271,7 +240,6 @@ namespace dlib
            // cudnnCreatePoolingDescriptor(), cudnnSetPooling2dDescriptor()

            max_pool (
-                cudnn_context& context,
                int window_height,
                int window_width,
                int stride_y,
@@ -310,7 +278,6 @@ namespace dlib
        // cudnnActivationForward(), CUDNN_ACTIVATION_SIGMOID

        void sigmoid (
-            cudnn_context& context,
            resizable_tensor& dest,
            const tensor& src
        );
@@ -323,7 +290,6 @@ namespace dlib
        // cudnnActivationBackward()

        void sigmoid_gradient (
-            cudnn_context& context,
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input
@@ -333,7 +299,7 @@ namespace dlib
                - have_same_dimensions(src,gradient_input) == true
                - have_same_dimensions(src,grad) == true
            ensures
-                - let OUT be the output of sigmoid(context,OUT,src)
+                - let OUT be the output of sigmoid(OUT,src)
                - let f(src) == dot(gradient_input,OUT)
                - Then this function computes the gradient of f() with respect to src and
                  adds it to grad.
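Concretely, OUT[i] = 1/(1 + exp(-src[i])) and f(src) = dot(gradient_input, OUT), so df/dsrc[i] = gradient_input[i]*OUT[i]*(1 - OUT[i]). A small CPU reference sketch of that rule (illustrative only, not the cudnnActivationBackward() path):

        #include <cmath>
        #include <cstddef>

        // Accumulates gradient_input[i]*OUT[i]*(1 - OUT[i]) into grad[i],
        // where OUT[i] = 1/(1 + exp(-src[i])).
        void sigmoid_gradient_ref (float* grad, const float* src,
                                   const float* gradient_input, std::size_t n)
        {
            for (std::size_t i = 0; i < n; ++i)
            {
                const float out = 1.0f/(1.0f + std::exp(-src[i]));
                grad[i] += gradient_input[i]*out*(1.0f - out);
            }
        }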
@@ -343,7 +309,6 @@ namespace dlib
        // cudnnActivationForward(), CUDNN_ACTIVATION_RELU

        void relu (
-            cudnn_context& context,
            resizable_tensor& dest,
            const tensor& src
        );
@@ -356,7 +321,6 @@ namespace dlib
        // cudnnActivationBackward()

        void relu_gradient (
-            cudnn_context& context,
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input
@@ -366,7 +330,7 @@ namespace dlib
                - have_same_dimensions(src,gradient_input) == true
                - have_same_dimensions(src,grad) == true
            ensures
-                - let OUT be the output of relu(context,OUT,src)
+                - let OUT be the output of relu(OUT,src)
                - let f(src) == dot(gradient_input,OUT)
                - Then this function computes the gradient of f() with respect to src and
                  adds it to grad.
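Here OUT[i] = max(0, src[i]), so the gradient of f(src) = dot(gradient_input, OUT) passes gradient_input through wherever src is positive and is zero elsewhere. A CPU reference sketch (illustrative only; the subgradient at src[i] == 0 is taken as 0):

        #include <cstddef>

        // Accumulates gradient_input[i] into grad[i] wherever src[i] > 0.
        void relu_gradient_ref (float* grad, const float* src,
                                const float* gradient_input, std::size_t n)
        {
            for (std::size_t i = 0; i < n; ++i)
                grad[i] += src[i] > 0 ? gradient_input[i] : 0.0f;
        }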
@@ -376,7 +340,6 @@ namespace dlib
        // cudnnActivationForward(), CUDNN_ACTIVATION_TANH

        void tanh (
-            cudnn_context& context,
            resizable_tensor& dest,
            const tensor& src
        );
@@ -389,7 +352,6 @@ namespace dlib
        // cudnnActivationBackward()

        void tanh_gradient (
-            cudnn_context& context,
            tensor& grad,
            const tensor& src,
            const tensor& gradient_input
@@ -399,7 +361,7 @@ namespace dlib
                - have_same_dimensions(src,gradient_input) == true
                - have_same_dimensions(src,grad) == true
            ensures
-                - let OUT be the output of tanh(context,OUT,src)
+                - let OUT be the output of tanh(OUT,src)
                - let f(src) == dot(gradient_input,OUT)
                - Then this function computes the gradient of f() with respect to src and
                  adds it to grad.
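With OUT[i] = tanh(src[i]) and f(src) = dot(gradient_input, OUT), the derivative is gradient_input[i]*(1 - OUT[i]*OUT[i]). A CPU reference sketch (illustrative only, not the cuDNN path):

        #include <cmath>
        #include <cstddef>

        // Accumulates gradient_input[i]*(1 - tanh(src[i])^2) into grad[i].
        void tanh_gradient_ref (float* grad, const float* src,
                                const float* gradient_input, std::size_t n)
        {
            for (std::size_t i = 0; i < n; ++i)
            {
                const float out = std::tanh(src[i]);
                grad[i] += gradient_input[i]*(1.0f - out*out);
            }
        }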