Commit 15fb2527 authored by Davis King

Gave the batch normalization layer an automatic testing mode that causes it to use the saved average mean and invstd to scale the data instead of normalizing the batch.
parent 2e01920f
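
In testing mode the layer stops computing per-batch statistics and instead applies a fixed per-element affine map built from the running mean and invstd accumulated during training. A minimal scalar sketch of that map, using illustrative names rather than the dlib tensor API:

#include <iostream>

// Test-time batch norm on one value: the batch statistics are replaced by the
// running mean/invstd saved during training, so the whole op collapses to a
// fixed affine transform y = gamma*(x - running_mean)*running_invstd + beta.
float bn_test_mode(float x, float running_mean, float running_invstd,
                   float gamma, float beta)
{
    // Precomputing nim = -running_mean*running_invstd lets the normalization be
    // applied as x*running_invstd + nim, which is what the diff below caches in
    // running_nim and applies with tt::affine_transform().
    const float nim = -running_mean*running_invstd;
    return gamma*(x*running_invstd + nim) + beta;
}

int main()
{
    std::cout << bn_test_mode(2.0f, 1.5f, 0.8f, 1.0f, 0.0f) << "\n"; // prints 0.4
}
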
@@ -301,7 +301,7 @@ namespace dlib
class bn_
{
public:
bn_()
bn_() : num_updates(0), running_stats_window_size(1000), running_nim_out_of_date(true)
{}
template <typename SUBNET>
@@ -317,6 +317,16 @@ namespace dlib
gamma(params,0) = 1;
beta(params,gamma.size()) = 0;
running_means.set_size(1,
sub.get_output().k(),
sub.get_output().nr(),
sub.get_output().nc());
running_invstds.copy_size(running_means);
running_means = 0;
running_invstds = 1;
num_updates = 0;
running_nim_out_of_date = true;
}
template <typename SUBNET>
@@ -324,7 +334,30 @@ namespace dlib
{
auto g = gamma(params,0);
auto b = beta(params,gamma.size());
if (sub.get_output().num_samples() > 1)
{
tt::batch_normalize(output, means, invstds, sub.get_output(), g, b);
const double decay = num_updates/(num_updates+1.0);
if (num_updates < running_stats_window_size)
++num_updates;
tt::affine_transform(running_means, running_means, means, decay, 1-decay, 0);
tt::affine_transform(running_invstds, running_invstds, invstds, decay, 1-decay, 0);
running_nim_out_of_date = true;
}
else // we are running in testing mode so we just linearly scale the input tensor.
{
if (running_nim_out_of_date)
{
running_nim_out_of_date = false;
running_nim.copy_size(running_means);
tt::multiply(running_nim, running_means, running_invstds);
running_nim *= -1;
}
output.copy_size(sub.get_output());
tt::affine_transform(output, sub.get_output(), running_invstds, running_nim);
tt::affine_transform(output, output, g, b);
}
}
template <typename SUBNET>
@@ -347,6 +380,10 @@ namespace dlib
serialize(item.beta, out);
serialize(item.means, out);
serialize(item.invstds, out);
serialize(item.running_means, out);
serialize(item.running_invstds, out);
serialize(item.num_updates, out);
serialize(item.running_stats_window_size, out);
}
friend void deserialize(bn_& item, std::istream& in)
@@ -360,6 +397,11 @@ namespace dlib
deserialize(item.beta, in);
deserialize(item.means, in);
deserialize(item.invstds, in);
deserialize(item.running_means, in);
deserialize(item.running_invstds, in);
deserialize(item.num_updates, in);
deserialize(item.running_stats_window_size, in);
item.running_nim_out_of_date = true;
}
private:
@@ -367,8 +409,13 @@ namespace dlib
tt::batch_normalize_gradient bng;
resizable_tensor params;
alias_tensor gamma, beta;
resizable_tensor means;
resizable_tensor invstds;
resizable_tensor means, running_means;
resizable_tensor invstds, running_invstds;
unsigned long num_updates;
unsigned long running_stats_window_size;
bool running_nim_out_of_date;
resizable_tensor running_nim;
};
template <typename SUBNET>
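
The decay schedule in the training branch deserves a note: decay = num_updates/(num_updates+1) makes the running statistics an exact cumulative average over the first running_stats_window_size batches, and because num_updates stops incrementing at 1000 the update then settles into a fixed-rate exponential moving average with decay 1000/1001. A small scalar sketch of that schedule, again with illustrative names rather than the dlib tensor calls used above:

#include <cstdio>

// Mirrors the running-statistics update in forward():
//   running = decay*running + (1-decay)*batch_stat, with decay = n/(n+1)
// and n capped at the stats window size so later batches keep a fixed weight.
struct running_stat
{
    double value = 0;
    unsigned long num_updates = 0;
    unsigned long window_size = 1000;

    void update(double batch_stat)
    {
        const double decay = num_updates/(num_updates+1.0);
        if (num_updates < window_size)
            ++num_updates;
        value = decay*value + (1-decay)*batch_stat;
    }
};

int main()
{
    running_stat mean;
    for (double x : {2.0, 4.0, 6.0})
        mean.update(x);
    std::printf("%f\n", mean.value); // 4.000000, the average of the first three batches
}
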