Commit 51ea50b3 authored by Davis King's avatar Davis King

More batch normalization optimizations.

parent 32125dea
......@@ -210,13 +210,14 @@ namespace dlib
const float dx = *p_grad * p_gamma[i];
p_dvars[i] += dx*(*p_src - p_means[i])* -0.5*std::pow(p_invstds[i], 3.0f);
p_dvars[i] += dx*(*p_src - p_means[i])*-0.5*std::pow(p_invstds[i], 3.0f);
++p_grad;
++p_src;
}
}
const float invnum = 1.0f/src.num_samples();
p_grad = gradient_input.host();
p_src = src.host();
for (long n = 0; n < src.num_samples(); ++n)
......@@ -225,7 +226,7 @@ namespace dlib
{
const float dx = *p_grad * p_gamma[i];
p_dmeans[i] += dx*-p_invstds[i] + p_dvars[i] * -2*(*p_src - p_means[i])/src.num_samples();
p_dmeans[i] += dx*-p_invstds[i] + p_dvars[i] * -2*(*p_src - p_means[i])*invnum;
++p_grad;
++p_src;
......@@ -241,8 +242,8 @@ namespace dlib
const float dx = *p_grad * p_gamma[i];
*p_src_grad += dx*p_invstds[i] +
p_dvars[i] *2*(*p_src - p_means[i])/src.num_samples() +
p_dmeans[i]/src.num_samples();
p_dvars[i] *2*(*p_src - p_means[i])*invnum +
p_dmeans[i]*invnum;
++p_grad;
......@@ -382,6 +383,7 @@ namespace dlib
{
for (long k = 0; k < src.k(); ++k)
{
const auto invstd_pow = -0.5*std::pow(p_invstds[k], 3.0f);
for (long i = 0; i < num; ++i)
{
const float x_hat = (*p_src - p_means[k])*p_invstds[k];
......@@ -390,7 +392,7 @@ namespace dlib
const float dx = *p_grad * p_gamma[k];
p_dvars[k] += dx*(*p_src - p_means[k])* -0.5*std::pow(p_invstds[k], 3.0f);
p_dvars[k] += dx*(*p_src - p_means[k])*invstd_pow;
++p_grad;
++p_src;
......@@ -400,6 +402,7 @@ namespace dlib
p_grad = gradient_input.host();
p_src = src.host();
const float invnum = 1.0f/(src.num_samples()*num);
for (long n = 0; n < src.num_samples(); ++n)
{
for (long k = 0; k < src.k(); ++k)
......@@ -408,7 +411,7 @@ namespace dlib
{
const float dx = *p_grad * p_gamma[k];
p_dmeans[k] += -dx*p_invstds[k] + p_dvars[k] * -2*(*p_src - p_means[k])/src.num_samples()/num;
p_dmeans[k] += -dx*p_invstds[k] + p_dvars[k] * -2*(*p_src - p_means[k])*invnum;
++p_grad;
++p_src;
......@@ -427,8 +430,8 @@ namespace dlib
const float dx = *p_grad * p_gamma[k];
*p_src_grad += dx*p_invstds[k] +
p_dvars[k] *2*(*p_src - p_means[k])/src.num_samples()/num +
p_dmeans[k]/src.num_samples()/num;
p_dvars[k]*2*(*p_src - p_means[k])*invnum +
p_dmeans[k]*invnum;
++p_grad;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment