Commit cbce85ec authored by Davis King

Added GPU versions of the batch normalization functions.

parent 06534305
...@@ -185,7 +185,7 @@ namespace dlib ...@@ -185,7 +185,7 @@ namespace dlib
for (long i = 0; i < num; ++i) for (long i = 0; i < num; ++i)
{ {
auto actual_var = p_invstds[i] - p_means[i]*p_means[i]; auto actual_var = p_invstds[i] - p_means[i]*p_means[i];
p_invstds[i] = 1.0/std::sqrt(actual_var+eps); p_invstds[i] = 1.0f/std::sqrt(actual_var+eps);
} }
p_src = src.host(); p_src = src.host();
...@@ -361,8 +361,8 @@ namespace dlib ...@@ -361,8 +361,8 @@ namespace dlib
// compute variances // compute variances
for (long k = 0; k < src.k(); ++k) for (long k = 0; k < src.k(); ++k)
{ {
auto actual_var = p_invstds[k] - p_means[k]*p_means[k]; float actual_var = p_invstds[k] - p_means[k]*p_means[k];
p_invstds[k] = 1.0/std::sqrt(actual_var + eps); p_invstds[k] = 1.0f/std::sqrt(actual_var + eps);
} }
p_src = src.host(); p_src = src.host();
...@@ -421,7 +421,7 @@ namespace dlib ...@@ -421,7 +421,7 @@ namespace dlib
{ {
for (long k = 0; k < src.k(); ++k) for (long k = 0; k < src.k(); ++k)
{ {
const auto invstd_pow = -0.5*std::pow(p_invstds[k], 3.0f); const float invstd_pow = -0.5*std::pow(p_invstds[k], 3.0f);
for (long i = 0; i < num; ++i) for (long i = 0; i < num; ++i)
{ {
const float x_hat = (*p_src - p_means[k])*p_invstds[k]; const float x_hat = (*p_src - p_means[k])*p_invstds[k];
......
This diff is collapsed.
...@@ -460,6 +460,107 @@ namespace ...@@ -460,6 +460,107 @@ namespace
} }
#endif #endif
// ----------------------------------------------------------------------------------------
void compare_bn_gpu_and_cpu()
{
print_spinner();
resizable_tensor dest, dest2;
resizable_tensor means, means2;
resizable_tensor invstds, invstds2;
resizable_tensor src(64,20,100,100);
resizable_tensor gamma(1,20,100,100);
resizable_tensor beta(1,20,100,100);
gamma = 2;
beta = 3;
tt::tensor_rand rnd;
rnd.fill_uniform(src);
cpu::batch_normalize(dest,means,invstds, src, gamma, beta);
cuda::batch_normalize(dest2,means2,invstds2, src, gamma, beta);
dlog << LINFO << "dest error: "<< max(abs(mat(dest) -mat(dest2)));
dlog << LINFO << "means error: "<< max(abs(mat(means) -mat(means2)));
dlog << LINFO << "invstds error: "<< max(abs(mat(invstds) -mat(invstds2)));
DLIB_TEST(max(abs(mat(dest) -mat(dest2))) < 1e-5);
DLIB_TEST(max(abs(mat(means) -mat(means2))) < 1e-5);
DLIB_TEST(max(abs(mat(invstds) -mat(invstds2))) < 1e-5);
// now check that the gradients match as well
resizable_tensor gradient_input;
resizable_tensor src_grad, gamma_grad, beta_grad;
resizable_tensor src_grad2, gamma_grad2, beta_grad2;
gradient_input.copy_size(dest);
src_grad.copy_size(src); src_grad = 0; src_grad2 = src_grad;
gamma_grad.copy_size(gamma); gamma_grad = 0; gamma_grad2 = gamma_grad;
beta_grad.copy_size(beta); beta_grad = 0; beta_grad2 = beta_grad;
rnd.fill_uniform(gradient_input);
cpu::batch_normalize_gradient cpu_bng;
cpu_bng(gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
cuda::batch_normalize_gradient cuda_bng;
cuda_bng(gradient_input, means, invstds, src, gamma, src_grad2, gamma_grad2, beta_grad2);
dlog << LINFO << "src_grad error: " << max(abs(mat(src_grad)-mat(src_grad2)));
dlog << LINFO << "gamma_grad error: " << max(abs(mat(gamma_grad)-mat(gamma_grad2)));
dlog << LINFO << "beta_grad error: " << max(abs(mat(beta_grad)-mat(beta_grad2)));
DLIB_TEST(max(abs(mat(src_grad)-mat(src_grad2))) < 1e-5);
DLIB_TEST(max(abs(mat(gamma_grad)-mat(gamma_grad2))) < 1e-5);
DLIB_TEST(max(abs(mat(beta_grad)-mat(beta_grad2))) < 1e-5);
}
void compare_bn_conv_gpu_and_cpu()
{
print_spinner();
resizable_tensor dest, dest2;
resizable_tensor means, means2;
resizable_tensor invstds, invstds2;
resizable_tensor src(2,8,10,9);
resizable_tensor gamma(1,8);
resizable_tensor beta(1,8);
gamma = 2;
beta = 3;
tt::tensor_rand rnd;
rnd.fill_uniform(src);
cpu::batch_normalize_conv(dest,means,invstds, src, gamma, beta);
cuda::batch_normalize_conv(dest2,means2,invstds2, src, gamma, beta);
dlog << LINFO << "dest error: "<< max(abs(mat(dest) -mat(dest2)));
dlog << LINFO << "means error: "<< max(abs(mat(means) -mat(means2)));
dlog << LINFO << "invstds error: "<< max(abs(mat(invstds) -mat(invstds2)));
DLIB_TEST(max(abs(mat(dest) -mat(dest2))) < 1e-4);
DLIB_TEST(max(abs(mat(means) -mat(means2))) < 1e-4);
DLIB_TEST(max(abs(mat(invstds) -mat(invstds2))) < 1e-4);
resizable_tensor gradient_input;
resizable_tensor src_grad, gamma_grad, beta_grad;
resizable_tensor src_grad2, gamma_grad2, beta_grad2;
gradient_input.copy_size(dest);
src_grad.copy_size(src); src_grad = 0; src_grad2 = src_grad;
gamma_grad.copy_size(gamma); gamma_grad = 0; gamma_grad2 = gamma_grad;
beta_grad.copy_size(beta); beta_grad = 0; beta_grad2 = beta_grad;
rnd.fill_uniform(gradient_input);
cpu::batch_normalize_conv_gradient cpu_bng;
cpu_bng(gradient_input, means, invstds, src, gamma, src_grad, gamma_grad, beta_grad);
cuda::batch_normalize_conv_gradient cuda_bng;
cuda_bng(gradient_input, means, invstds, src, gamma, src_grad2, gamma_grad2, beta_grad2);
dlog << LINFO << "src_grad error: " << max(abs(mat(src_grad)-mat(src_grad2)));
dlog << LINFO << "gamma_grad error: " << max(abs(mat(gamma_grad)-mat(gamma_grad2)));
dlog << LINFO << "beta_grad error: " << max(abs(mat(beta_grad)-mat(beta_grad2)));
DLIB_TEST(max(abs(mat(src_grad)-mat(src_grad2))) < 1e-4);
DLIB_TEST(max(abs(mat(gamma_grad)-mat(gamma_grad2))) < 1e-4);
DLIB_TEST(max(abs(mat(beta_grad)-mat(beta_grad2))) < 1e-4);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
class dnn_tester : public tester class dnn_tester : public tester
...@@ -488,6 +589,8 @@ namespace ...@@ -488,6 +589,8 @@ namespace
test_batch_normalize(); test_batch_normalize();
test_batch_normalize_conv(); test_batch_normalize_conv();
test_basic_tensor_ops(); test_basic_tensor_ops();
compare_bn_gpu_and_cpu();
compare_bn_conv_gpu_and_cpu();
} }
} a; } a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment