Commit 920cf4db authored by Davis King's avatar Davis King

Added ADAM tests

parent ab605d15
...@@ -568,6 +568,36 @@ namespace ...@@ -568,6 +568,36 @@ namespace
} }
} }
void compare_adam()
{
float t = 2;
tt::tensor_rand rnd;
resizable_tensor s, m, v, params, params_grad;
s.set_size(89,90,60,73);
m.copy_size(s);
v.copy_size(s);
params.copy_size(s);
params_grad.copy_size(s);
rnd.fill_uniform(s);
rnd.fill_uniform(m);
rnd.fill_uniform(v);
rnd.fill_uniform(params);
rnd.fill_uniform(params_grad);
resizable_tensor mm(m), vv(v);
cpu::compute_adam_update(s, mm, vv, t, 0.01, 0.001, 0.9, 0.99, params, params_grad);
matrix<float> s1 = mat(s);
rnd.fill_uniform(s);
cuda::compute_adam_update(s, m, v, t, 0.01, 0.001, 0.9, 0.99, params, params_grad);
matrix<float> s2 = mat(s);
DLIB_TEST_MSG(max(abs(s1-s2)) < 1e-6, max(abs(s1-s2)));
DLIB_TEST_MSG(max(abs(mat(m)-mat(mm))) < 1e-6, max(abs(mat(m)-mat(mm))));
DLIB_TEST_MSG(max(abs(mat(v)-mat(vv))) < 1e-6, max(abs(mat(v)-mat(vv))));
}
void test_add() void test_add()
{ {
print_spinner(); print_spinner();
...@@ -1190,6 +1220,7 @@ namespace ...@@ -1190,6 +1220,7 @@ namespace
compare_bn_gpu_and_cpu(); compare_bn_gpu_and_cpu();
compare_bn_conv_gpu_and_cpu(); compare_bn_conv_gpu_and_cpu();
test_add(); test_add();
compare_adam();
#endif #endif
test_max_pool(1,1,2,3); test_max_pool(1,1,2,3);
test_max_pool(3,3,1,1); test_max_pool(3,3,1,1);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment