Commit 04797b66 authored by Davis King

Added softmax

parent 91234133
@@ -499,24 +499,6 @@ namespace dlib
filters_gradient.device()));
}
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
void soft_max (
resizable_tensor& dest,
const tensor& src
)
{
}
void soft_max_gradient (
tensor& grad,
const tensor& src,
const tensor& gradient_input
)
{
}
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
@@ -550,6 +532,60 @@ namespace dlib
{
}
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
void softmax (
resizable_tensor& dest,
const tensor& src
)
{
dest.copy_size(src);
if (src.size() == 0)
return;
const float alpha = 1;
const float beta = 0;
check(cudnnSoftmaxForward(context(),
CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
descriptor(src),
src.device(),
&beta,
descriptor(dest),
dest.device()));
}
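For reference, the following is a minimal CPU sketch (not part of this commit) of what the cudnnSoftmaxForward call above computes with CUDNN_SOFTMAX_MODE_CHANNEL and CUDNN_SOFTMAX_ACCURATE.  It assumes dlib's NCHW host layout, where element (n,k,r,c) lives at index ((n*K + k)*NR + r)*NC + c, and it needs <cmath> and <algorithm>:

    // Illustrative sketch only: softmax taken across the k channels at each
    // spatial location of each sample.
    void softmax_reference (
        resizable_tensor& dest,
        const tensor& src
    )
    {
        dest.copy_size(src);
        const long N = src.num_samples(), K = src.k(), NR = src.nr(), NC = src.nc();
        const float* in = src.host();
        float* out = dest.host();
        for (long n = 0; n < N; ++n)
        for (long r = 0; r < NR; ++r)
        for (long c = 0; c < NC; ++c)
        {
            auto idx = [&](long k) { return ((n*K + k)*NR + r)*NC + c; };
            // Subtract the per-location max before exponentiating.  This is
            // the numerical stabilization the "ACCURATE" mode refers to.
            float m = in[idx(0)];
            for (long k = 1; k < K; ++k)
                m = std::max(m, in[idx(k)]);
            float sum = 0;
            for (long k = 0; k < K; ++k)
            {
                out[idx(k)] = std::exp(in[idx(k)] - m);
                sum += out[idx(k)];
            }
            for (long k = 0; k < K; ++k)
                out[idx(k)] /= sum;
        }
    }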
void softmax_gradient (
tensor& grad,
const tensor& softmaxed_data,
const tensor& gradient_input
)
{
DLIB_CASSERT(
have_same_dimensions(softmaxed_data,gradient_input) == true &&
have_same_dimensions(softmaxed_data,grad) == true , "");
if (softmaxed_data.size() == 0)
return;
const float alpha = 1;
const float beta = 1;
check(cudnnSoftmaxBackward(context(),
CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_CHANNEL,
&alpha,
descriptor(softmaxed_data),
softmaxed_data.device(),
descriptor(gradient_input),
gradient_input.device(),
&beta,
descriptor(grad),
grad.device()));
}
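And a matching sketch of what the cudnnSoftmaxBackward call accumulates.  Note that beta == 1 above, so cuDNN adds the result to grad rather than overwriting it; again this is only an illustration under the same NCHW layout assumption, not part of the commit:

    // Illustrative sketch only: per spatial location, with s = softmaxed_data
    // and g = gradient_input, the backward pass adds s_k*(g_k - dot(g,s)) to
    // grad for each channel k.
    void softmax_gradient_reference (
        tensor& grad,
        const tensor& softmaxed_data,
        const tensor& gradient_input
    )
    {
        const long N = grad.num_samples(), K = grad.k(), NR = grad.nr(), NC = grad.nc();
        const float* s = softmaxed_data.host();
        const float* g = gradient_input.host();
        float* out = grad.host();
        for (long n = 0; n < N; ++n)
        for (long r = 0; r < NR; ++r)
        for (long c = 0; c < NC; ++c)
        {
            auto idx = [&](long k) { return ((n*K + k)*NR + r)*NC + c; };
            float dot = 0;
            for (long k = 0; k < K; ++k)
                dot += g[idx(k)]*s[idx(k)];
            for (long k = 0; k < K; ++k)
                out[idx(k)] += s[idx(k)]*(g[idx(k)] - dot);
        }
    }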
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
@@ -151,6 +151,8 @@ namespace dlib
/*!
requires
- filters.k() == data.k()
- stride_y > 0
- stride_x > 0
!*/
~conv (
@@ -237,28 +239,6 @@ namespace dlib
void* backward_filters_workspace;
};
// ------------------------------------------------------------------------------------
void soft_max (
resizable_tensor& dest,
const tensor& src
);
/*!
probably uses CUDNN_SOFTMAX_MODE_CHANNEL
!*/
void soft_max_gradient (
tensor& grad,
const tensor& src,
const tensor& gradient_input
);
/*!
- let OUT be the output of soft_max(OUT,src)
- let f(src) == dot(gradient_input,OUT)
- Then this function computes the gradient of f() with respect to src
and adds it to grad.
!*/
// ------------------------------------------------------------------------------------
class max_pool
@@ -307,6 +287,41 @@ namespace dlib
// TODO, make the order of parameters of all these functions consistent.
// ------------------------------------------------------------------------------------
void softmax (
resizable_tensor& dest,
const tensor& src
);
/*!
ensures
- have_same_dimensions(#dest, src) == true
- Note that the softmax function is a vector valued function:
s(x) == exp(x)/sum(exp(x))
- Computes the softmax function on src and writes the results to dest. The
softmax is computed per spatial location across the different channels at
each location. That is, softmax() outputs a new tensor, #dest, where
each of the spatial locations in dest (i.e. image idx, row idx, and
column idx) contains the output of s() evaluated over the channel values
at each location.
!*/
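As a hypothetical usage sketch of the ensures clause above (not part of this commit, and assuming the NCHW host layout used by dlib tensors and that tensor::host() reflects the device data), the channel values at every spatial location of #dest should sum to 1:

    resizable_tensor src, dest;
    src.set_size(2,5,3,3);                       // 2 samples, 5 channels, 3x3 spatial
    for (size_t i = 0; i < src.size(); ++i)
        src.host()[i] = (float)std::sin((double)i);  // arbitrary test values
    softmax(dest, src);
    for (long n = 0; n < dest.num_samples(); ++n)
    for (long r = 0; r < dest.nr(); ++r)
    for (long c = 0; c < dest.nc(); ++c)
    {
        float sum = 0;
        for (long k = 0; k < dest.k(); ++k)
            sum += dest.host()[((n*dest.k()+k)*dest.nr()+r)*dest.nc()+c];
        // each spatial location's channel values form a probability distribution
        DLIB_CASSERT(std::abs(sum-1) < 1e-5, "");
    }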
void softmax_gradient (
tensor& grad,
const tensor& softmaxed_data,
const tensor& gradient_input
);
/*!
requires
- have_same_dimensions(softmaxed_data,gradient_input) == true
- have_same_dimensions(softmaxed_data,grad) == true
ensures
- We interpret softmaxed_data as the output of softmax(softmaxed_data,SRC)
for some SRC tensor.  Then let f(SRC) == dot(gradient_input,softmaxed_data).
Then this function computes the gradient of f() with respect to SRC and
adds it to grad.
!*/
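The quantity this adds to grad follows from the chain rule.  Writing s for softmaxed_data and g for gradient_input at one spatial location (sketched here in LaTeX; this derivation is not part of the original comment):

    f(\mathrm{SRC}) = \sum_i g_i\, s_i, \qquad
    \frac{\partial s_i}{\partial \mathrm{SRC}_j} = s_i(\delta_{ij} - s_j)
    \;\Longrightarrow\;
    \frac{\partial f}{\partial \mathrm{SRC}_j}
      = \sum_i g_i\, s_i(\delta_{ij} - s_j)
      = s_j\Big(g_j - \sum_i g_i s_i\Big)

so the value accumulated into grad at channel j of each location is s_j*(g_j - dot(g,s)).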
// ------------------------------------------------------------------------------------
void sigmoid (
@@ -334,7 +349,7 @@ namespace dlib
- dest contains the result of calling sigmoid(dest,src)
ensures
- Recalling that dest is the output of sigmoid(dest,src),
let f(src) == dot(gradient_input,dist)
let f(src) == dot(gradient_input,dest)
- Then this function computes the gradient of f() with respect to src and
adds it to grad.
!*/
@@ -366,7 +381,7 @@ namespace dlib
- dest contains the result of calling relu(dest,src)
ensures
- Recalling that dest is the output of relu(dest,src),
let f(src) == dot(gradient_input,dist)
let f(src) == dot(gradient_input,dest)
- Then this function computes the gradient of f() with respect to src and
adds it to grad.
!*/
@@ -398,7 +413,7 @@ namespace dlib
- dest contains the result of calling tanh(dest,src)
ensures
- Recalling that dest is the output of tanh(dest,src),
let f(src) == dot(gradient_input,dist)
let f(src) == dot(gradient_input,dest)
- Then this function computes the gradient of f() with respect to src and
adds it to grad.
!*/