Commit a7ea7d00 authored by Davis King's avatar Davis King

Implemented CPU version of tanh

parent 9b36bb98
...@@ -650,8 +650,10 @@ namespace dlib ...@@ -650,8 +650,10 @@ namespace dlib
const tensor& src const tensor& src
) )
{ {
// TODO const auto d = dest.host();
DLIB_CASSERT(false,""); const auto s = src.host();
for (size_t i = 0; i < src.size(); ++i)
d[i] = std::tanh(s[i]);
} }
void tanh_gradient ( void tanh_gradient (
...@@ -660,8 +662,11 @@ namespace dlib ...@@ -660,8 +662,11 @@ namespace dlib
const tensor& gradient_input const tensor& gradient_input
) )
{ {
// TODO const auto g = grad.host();
DLIB_CASSERT(false,""); const auto d = dest.host();
const auto in = gradient_input.host();
for (size_t i = 0; i < dest.size(); ++i)
g[i] = in[i]*(1-d[i]*d[i]);
} }
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
......
...@@ -39,6 +39,43 @@ namespace ...@@ -39,6 +39,43 @@ namespace
return max_error; return max_error;
} }
// ----------------------------------------------------------------------------------------
void test_tanh()
{
print_spinner();
resizable_tensor src(5,5), dest(5,5), gradient_input(5,5);
src = matrix_cast<float>(gaussian_randm(5,5, 0));
dest = matrix_cast<float>(gaussian_randm(5,5, 1));
gradient_input = matrix_cast<float>(gaussian_randm(5,5, 2));
auto grad_src = [&](long idx) {
auto f = [&](float eps) {
const float old = src.host()[idx];
src.host()[idx] += eps;
tanh(dest, src);
float result = dot(gradient_input, dest);
src.host()[idx] = old;
return result;
};
const float eps = 0.01;
return (f(+eps)-f(-eps))/(2*eps);
};
resizable_tensor src_grad;
src_grad.copy_size(src);
src_grad = 0;
tanh(dest, src);
tanh_gradient(src_grad, dest, gradient_input);
auto grad_error = compare_gradients(src_grad, grad_src);
dlog << LINFO << "src error: " << grad_error;
DLIB_TEST(grad_error < 0.001);
}
void test_sigmoid() void test_sigmoid()
{ {
print_spinner(); print_spinner();
...@@ -324,6 +361,7 @@ namespace ...@@ -324,6 +361,7 @@ namespace
void perform_test ( void perform_test (
) )
{ {
test_tanh();
test_softmax(); test_softmax();
test_sigmoid(); test_sigmoid();
test_batch_normalize(); test_batch_normalize();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment