Commit 7550681b authored by Davis King's avatar Davis King

Implemented the bn layer.

parent 363b6b2f
......@@ -187,27 +187,68 @@ namespace dlib
template <typename SUBNET>
void setup (const SUBNET& sub)
{
// TODO
gamma = alias_tensor(1,
sub.get_output().k(),
sub.get_output().nr(),
sub.get_output().nc());
beta = gamma;
params.set_size(gamma.size()+beta.size());
gamma(params,0) = 1;
beta(params,gamma.size()) = 0;
}
template <typename SUBNET>
void forward(const SUBNET& sub, resizable_tensor& output)
{
// TODO
auto g = gamma(params,0);
auto b = beta(params,gamma.size());
tt::batch_normalize(output, means, invstds, sub.get_output(), g, b);
}
template <typename SUBNET>
void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
{
// TODO
auto g = gamma(params,0);
auto g_grad = gamma(params_grad, 0);
auto b_grad = beta(params_grad, gamma.size());
bng(gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad);
}
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
friend void serialize(const bn_& item, std::ostream& out)
{
serialize("bn_", out);
serialize(item.params, out);
serialize(item.gamma, out);
serialize(item.beta, out);
serialize(item.means, out);
serialize(item.invstds, out);
}
friend void deserialize(bn_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "bn_")
throw serialization_error("Unexpected version found while deserializing dlib::bn_.");
deserialize(item.params, in);
deserialize(item.gamma, in);
deserialize(item.beta, in);
deserialize(item.means, in);
deserialize(item.invstds, in);
}
private:
tt::batch_normalize_gradient bng;
resizable_tensor params;
alias_tensor gamma, beta;
resizable_tensor means;
resizable_tensor invstds;
};
template <typename SUBNET>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment