Commit 91234133 authored by Davis King

Added cuDNN activation functions

parent 45b2c06a
@@ -558,14 +558,50 @@ namespace dlib
const tensor& src
)
{
dest.copy_size(src);
if (src.size() == 0)
return;
const float alpha = 1;
const float beta = 0;
check(cudnnActivationForward(context(),
CUDNN_ACTIVATION_SIGMOID,
&alpha,
descriptor(src),
src.device(),
&beta,
descriptor(dest),
dest.device()));
}
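For reference, alpha = 1 and beta = 0 mean the cuDNN call simply overwrites dest (cuDNN blends new results with the prior contents of dest as alpha*result + beta*dest). A rough CPU equivalent of what this forward call produces, written against the tensor interface used in the spec file below (the helper name is made up for illustration, not part of the commit):

void sigmoid_cpu_reference (resizable_tensor& dest, const tensor& src)
{
    // assumes <cmath> for std::exp
    dest.copy_size(src);
    float* d = dest.host();
    const float* s = src.host();
    for (size_t i = 0; i < src.size(); ++i)
        d[i] = 1/(1+std::exp(-s[i]));   // #dest.host()[i] == 1/(1+std::exp(-src.host()[i]))
}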
void sigmoid_gradient (
tensor& grad,
const tensor& dest,
const tensor& src,
const tensor& gradient_input
)
{
DLIB_CASSERT(
have_same_dimensions(src,gradient_input) == true &&
have_same_dimensions(src,grad) == true &&
have_same_dimensions(src,dest) == true , "");
if (src.size() == 0)
return;
const float alpha = 1;
const float beta = 1;
check(cudnnActivationBackward(context(),
CUDNN_ACTIVATION_SIGMOID,
&alpha,
descriptor(dest),
dest.device(),
descriptor(gradient_input),
gradient_input.device(),
descriptor(src),
src.device(),
&beta,
descriptor(grad),
grad.device()));
}
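The sigmoid derivative can be written entirely in terms of the forward output, which is why the gradient routine now takes dest. Since beta = 1 above, cuDNN adds the result into grad rather than overwriting it. A CPU sketch of the same accumulating update (helper name invented for illustration):

void sigmoid_gradient_cpu_reference (
    tensor& grad, const tensor& dest, const tensor& gradient_input)
{
    float* g = grad.host();
    const float* d = dest.host();
    const float* gi = gradient_input.host();
    for (size_t i = 0; i < dest.size(); ++i)
        g[i] += gi[i]*d[i]*(1-d[i]);   // d/dx sigmoid(x) == sigmoid(x)*(1-sigmoid(x)) == dest*(1-dest)
}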
// ------------------------------------------------------------------------------------
@@ -575,14 +611,50 @@ namespace dlib
const tensor& src
)
{
dest.copy_size(src);
if (src.size() == 0)
return;
const float alpha = 1;
const float beta = 0;
check(cudnnActivationForward(context(),
CUDNN_ACTIVATION_RELU,
&alpha,
descriptor(src),
src.device(),
&beta,
descriptor(dest),
dest.device()));
}
void relu_gradient (
tensor& grad,
const tensor& dest,
const tensor& src,
const tensor& gradient_input
)
{
DLIB_CASSERT(
have_same_dimensions(src,gradient_input) == true &&
have_same_dimensions(src,grad) == true &&
have_same_dimensions(src,dest) == true , "");
if (src.size() == 0)
return;
const float alpha = 1;
const float beta = 1;
check(cudnnActivationBackward(context(),
CUDNN_ACTIVATION_RELU,
&alpha,
descriptor(dest),
dest.device(),
descriptor(gradient_input),
gradient_input.device(),
descriptor(src),
src.device(),
&beta,
descriptor(grad),
grad.device()));
}
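As above, beta = 1 makes the call accumulate into grad. For ReLU the derivative is 1 where the input was positive and 0 elsewhere, so a CPU sketch of the update would be (illustrative helper, not part of the commit):

void relu_gradient_cpu_reference (
    tensor& grad, const tensor& src, const tensor& gradient_input)
{
    float* g = grad.host();
    const float* s = src.host();
    const float* gi = gradient_input.host();
    for (size_t i = 0; i < src.size(); ++i)
        g[i] += (s[i] > 0 ? gi[i] : 0.0f);   // derivative of max(0,x) is 1 for x > 0, else 0
}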
// ------------------------------------------------------------------------------------
@@ -592,14 +664,50 @@ namespace dlib
const tensor& src
)
{
dest.copy_size(src);
if (src.size() == 0)
return;
const float alpha = 1;
const float beta = 0;
check(cudnnActivationForward(context(),
CUDNN_ACTIVATION_TANH,
&alpha,
descriptor(src),
src.device(),
&beta,
descriptor(dest),
dest.device()));
}
void tanh_gradient (
tensor& grad,
const tensor& dest,
const tensor& src,
const tensor& gradient_input
)
{
DLIB_CASSERT(
have_same_dimensions(src,gradient_input) == true &&
have_same_dimensions(src,grad) == true &&
have_same_dimensions(src,dest) == true , "");
if (src.size() == 0)
return;
const float alpha = 1;
const float beta = 1;
check(cudnnActivationBackward(context(),
CUDNN_ACTIVATION_TANH,
&alpha,
descriptor(dest),
dest.device(),
descriptor(gradient_input),
gradient_input.device(),
descriptor(src),
src.device(),
&beta,
descriptor(grad),
grad.device()));
}
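Like sigmoid, the tanh derivative only needs the forward output: d/dx tanh(x) == 1 - tanh(x)^2 == 1 - dest^2. A CPU sketch of the accumulating update (illustrative helper only):

void tanh_gradient_cpu_reference (
    tensor& grad, const tensor& dest, const tensor& gradient_input)
{
    float* g = grad.host();
    const float* d = dest.host();
    const float* gi = gradient_input.host();
    for (size_t i = 0; i < dest.size(); ++i)
        g[i] += gi[i]*(1-d[i]*d[i]);   // 1 - dest^2
}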
// ------------------------------------------------------------------------------------
@@ -309,7 +309,6 @@ namespace dlib
// ------------------------------------------------------------------------------------
// cudnnActivationForward(), CUDNN_ACTIVATION_SIGMOID
void sigmoid (
resizable_tensor& dest,
const tensor& src
@@ -321,9 +320,9 @@ namespace dlib
- #dest.host()[i] == 1/(1+std::exp(-src.host()[i]))
!*/
// cudnnActivationBackward()
void sigmoid_gradient (
tensor& grad,
const tensor& dest,
const tensor& src,
const tensor& gradient_input
);
@@ -331,16 +330,17 @@ namespace dlib
requires
- have_same_dimensions(src,gradient_input) == true
- have_same_dimensions(src,grad) == true
- have_same_dimensions(src,dest) == true
- dest contains the result of calling sigmoid(dest,src)
ensures
- Recalling that dest is the output of sigmoid(dest,src),
  let f(src) == dot(gradient_input,dest)
- Then this function computes the gradient of f() with respect to src and
  adds it to grad.
!*/
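A possible call sequence for the new interface, sketched here for illustration (the variable names are made up; grad must already have src's dimensions, and the computed gradient is added to whatever grad already holds):

resizable_tensor out;
sigmoid(out, src);                                 // forward pass; out is resized to match src
// ... produce gradient_input with the same dimensions as src ...
sigmoid_gradient(grad, out, src, gradient_input);  // accumulates df/dsrc into grad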
// ------------------------------------------------------------------------------------
// cudnnActivationForward(), CUDNN_ACTIVATION_RELU
void relu (
resizable_tensor& dest,
const tensor& src
@@ -352,9 +352,9 @@ namespace dlib
- #dest.host()[i] == std::max(0,src.host()[i])
!*/
// cudnnActivationBackward()
void relu_gradient (
tensor& grad,
const tensor& dest,
const tensor& src,
const tensor& gradient_input
);
@@ -362,16 +362,17 @@ namespace dlib
requires
- have_same_dimensions(src,gradient_input) == true
- have_same_dimensions(src,grad) == true
- have_same_dimensions(src,dest) == true
- dest contains the result of calling relu(dest,src)
ensures
- Recalling that dest is the output of relu(dest,src),
  let f(src) == dot(gradient_input,dest)
- Then this function computes the gradient of f() with respect to src and
  adds it to grad.
!*/
// ------------------------------------------------------------------------------------
// cudnnActivationForward(), CUDNN_ACTIVATION_TANH
void tanh (
resizable_tensor& dest,
const tensor& src
@@ -383,9 +384,9 @@ namespace dlib
- #dest.host()[i] == std::tanh(src.host()[i])
!*/
// cudnnActivationBackward()
void tanh_gradient (
tensor& grad,
const tensor& dest,
const tensor& src,
const tensor& gradient_input
);
@@ -393,9 +394,11 @@ namespace dlib
requires
- have_same_dimensions(src,gradient_input) == true
- have_same_dimensions(src,grad) == true
- have_same_dimensions(src,dest) == true
- dest contains the result of calling tanh(dest,src)
ensures
- Recalling that dest is the output of tanh(dest,src),
  let f(src) == dot(gradient_input,dest)
- Then this function computes the gradient of f() with respect to src and
  adds it to grad.
!*/