Commit 15fb2527 authored by Davis King's avatar Davis King

Gave the batch normalization layer an automatic testing mode: when it sees
a batch of size 1 it uses the saved running mean and invstd to linearly
scale the data instead of normalizing statistics computed from the batch.
parent 2e01920f
...@@ -301,7 +301,7 @@ namespace dlib ...@@ -301,7 +301,7 @@ namespace dlib
class bn_ class bn_
{ {
public: public:
bn_() bn_() : num_updates(0), running_stats_window_size(1000), running_nim_out_of_date(true)
{} {}
template <typename SUBNET> template <typename SUBNET>
...@@ -317,6 +317,16 @@ namespace dlib ...@@ -317,6 +317,16 @@ namespace dlib
gamma(params,0) = 1; gamma(params,0) = 1;
beta(params,gamma.size()) = 0; beta(params,gamma.size()) = 0;
running_means.set_size(1,
sub.get_output().k(),
sub.get_output().nr(),
sub.get_output().nc());
running_invstds.copy_size(running_means);
running_means = 0;
running_invstds = 1;
num_updates = 0;
running_nim_out_of_date = true;
} }
template <typename SUBNET> template <typename SUBNET>
...@@ -324,7 +334,30 @@ namespace dlib ...@@ -324,7 +334,30 @@ namespace dlib
{ {
auto g = gamma(params,0); auto g = gamma(params,0);
auto b = beta(params,gamma.size()); auto b = beta(params,gamma.size());
if (sub.get_output().num_samples() > 1)
{
tt::batch_normalize(output, means, invstds, sub.get_output(), g, b); tt::batch_normalize(output, means, invstds, sub.get_output(), g, b);
const double decay = num_updates/(num_updates+1.0);
if (num_updates <running_stats_window_size)
++num_updates;
tt::affine_transform(running_means, running_means, means, decay, 1-decay, 0);
tt::affine_transform(running_invstds, running_invstds, invstds, decay, 1-decay, 0);
running_nim_out_of_date = true;
}
else // we are running in testing mode so we just linearly scale the input tensor.
{
if (running_nim_out_of_date)
{
running_nim_out_of_date = false;
running_nim.copy_size(running_means);
tt::multiply(running_nim, running_means, running_invstds);
running_nim *= -1;
}
output.copy_size(sub.get_output());
tt::affine_transform(output, sub.get_output(), running_invstds, running_nim);
tt::affine_transform(output, output, g, b);
}
} }
template <typename SUBNET> template <typename SUBNET>
...@@ -347,6 +380,10 @@ namespace dlib ...@@ -347,6 +380,10 @@ namespace dlib
serialize(item.beta, out); serialize(item.beta, out);
serialize(item.means, out); serialize(item.means, out);
serialize(item.invstds, out); serialize(item.invstds, out);
serialize(item.running_means, out);
serialize(item.running_invstds, out);
serialize(item.num_updates, out);
serialize(item.running_stats_window_size, out);
} }
friend void deserialize(bn_& item, std::istream& in) friend void deserialize(bn_& item, std::istream& in)
...@@ -360,6 +397,11 @@ namespace dlib ...@@ -360,6 +397,11 @@ namespace dlib
deserialize(item.beta, in); deserialize(item.beta, in);
deserialize(item.means, in); deserialize(item.means, in);
deserialize(item.invstds, in); deserialize(item.invstds, in);
deserialize(item.running_means, in);
deserialize(item.running_invstds, in);
deserialize(item.num_updates, in);
deserialize(item.running_stats_window_size, in);
item.running_nim_out_of_date = true;
} }
private: private:
...@@ -367,8 +409,13 @@ namespace dlib ...@@ -367,8 +409,13 @@ namespace dlib
tt::batch_normalize_gradient bng; tt::batch_normalize_gradient bng;
resizable_tensor params; resizable_tensor params;
alias_tensor gamma, beta; alias_tensor gamma, beta;
resizable_tensor means; resizable_tensor means, running_means;
resizable_tensor invstds; resizable_tensor invstds, running_invstds;
unsigned long num_updates;
unsigned long running_stats_window_size;
bool running_nim_out_of_date;
resizable_tensor running_nim;
}; };
template <typename SUBNET> template <typename SUBNET>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment