Commit e5ad9590 authored by Davis King's avatar Davis King

Added bias learning rate and weight decay multipliers to bn_ layers

parent b6b83798
...@@ -666,6 +666,8 @@ namespace dlib ...@@ -666,6 +666,8 @@ namespace dlib
running_stats_window_size(window_size), running_stats_window_size(window_size),
learning_rate_multiplier(1), learning_rate_multiplier(1),
weight_decay_multiplier(0), weight_decay_multiplier(0),
bias_learning_rate_multiplier(1),
bias_weight_decay_multiplier(1),
eps(eps_) eps(eps_)
{} {}
...@@ -680,6 +682,11 @@ namespace dlib ...@@ -680,6 +682,11 @@ namespace dlib
void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; } void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; } void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; }
double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; }
double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
template <typename SUBNET> template <typename SUBNET>
void setup (const SUBNET& sub) void setup (const SUBNET& sub)
...@@ -765,6 +772,8 @@ namespace dlib ...@@ -765,6 +772,8 @@ namespace dlib
serialize(item.running_stats_window_size, out); serialize(item.running_stats_window_size, out);
serialize(item.learning_rate_multiplier, out); serialize(item.learning_rate_multiplier, out);
serialize(item.weight_decay_multiplier, out); serialize(item.weight_decay_multiplier, out);
serialize(item.bias_learning_rate_multiplier, out);
serialize(item.bias_weight_decay_multiplier, out);
serialize(item.eps, out); serialize(item.eps, out);
} }
...@@ -812,6 +821,8 @@ namespace dlib ...@@ -812,6 +821,8 @@ namespace dlib
{ {
deserialize(item.learning_rate_multiplier, in); deserialize(item.learning_rate_multiplier, in);
deserialize(item.weight_decay_multiplier, in); deserialize(item.weight_decay_multiplier, in);
deserialize(item.bias_learning_rate_multiplier, in);
deserialize(item.bias_weight_decay_multiplier, in);
deserialize(item.eps, in); deserialize(item.eps, in);
} }
else else
...@@ -834,6 +845,8 @@ namespace dlib ...@@ -834,6 +845,8 @@ namespace dlib
out << " eps="<<item.eps; out << " eps="<<item.eps;
out << " learning_rate_mult="<<item.learning_rate_multiplier; out << " learning_rate_mult="<<item.learning_rate_multiplier;
out << " weight_decay_mult="<<item.weight_decay_multiplier; out << " weight_decay_mult="<<item.weight_decay_multiplier;
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
return out; return out;
} }
...@@ -849,6 +862,8 @@ namespace dlib ...@@ -849,6 +862,8 @@ namespace dlib
unsigned long running_stats_window_size; unsigned long running_stats_window_size;
double learning_rate_multiplier; double learning_rate_multiplier;
double weight_decay_multiplier; double weight_decay_multiplier;
double bias_learning_rate_multiplier;
double bias_weight_decay_multiplier;
double eps; double eps;
}; };
......
...@@ -859,6 +859,8 @@ namespace dlib ...@@ -859,6 +859,8 @@ namespace dlib
- #get_running_stats_window_size() == 1000 - #get_running_stats_window_size() == 1000
- #get_learning_rate_multiplier() == 1 - #get_learning_rate_multiplier() == 1
- #get_weight_decay_multiplier() == 0 - #get_weight_decay_multiplier() == 0
- #get_bias_learning_rate_multiplier() == 1
- #get_bias_weight_decay_multiplier() == 1
- #get_eps() == tt::DEFAULT_BATCH_NORM_EPS - #get_eps() == tt::DEFAULT_BATCH_NORM_EPS
!*/ !*/
...@@ -874,6 +876,8 @@ namespace dlib ...@@ -874,6 +876,8 @@ namespace dlib
- #get_running_stats_window_size() == window_size - #get_running_stats_window_size() == window_size
- #get_learning_rate_multiplier() == 1 - #get_learning_rate_multiplier() == 1
- #get_weight_decay_multiplier() == 0 - #get_weight_decay_multiplier() == 0
- #get_bias_learning_rate_multiplier() == 1
- #get_bias_weight_decay_multiplier() == 1
- #get_eps() == eps - #get_eps() == eps
!*/ !*/
...@@ -953,6 +957,44 @@ namespace dlib ...@@ -953,6 +957,44 @@ namespace dlib
- #get_weight_decay_multiplier() == val - #get_weight_decay_multiplier() == val
!*/ !*/
double get_bias_learning_rate_multiplier(
) const;
/*!
ensures
- returns a multiplier number. The interpretation is that this object is
requesting that the learning rate used to optimize its bias parameters be
multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier().
!*/
double get_bias_weight_decay_multiplier(
) const;
/*!
ensures
- returns a multiplier number. The interpretation is that this object is
requesting that the weight decay used to optimize its bias parameters be
multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier().
!*/
void set_bias_learning_rate_multiplier(
double val
);
/*!
requires
- val >= 0
ensures
- #get_bias_learning_rate_multiplier() == val
!*/
void set_bias_weight_decay_multiplier(
double val
);
/*!
requires
- val >= 0
ensures
- #get_bias_weight_decay_multiplier() == val
!*/
template <typename SUBNET> void setup (const SUBNET& sub); template <typename SUBNET> void setup (const SUBNET& sub);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output); template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
......
...@@ -89,6 +89,17 @@ namespace dlib ...@@ -89,6 +89,17 @@ namespace dlib
return v; return v;
} }
template < layer_mode mode >
const tensor& operator() (
const float learning_rate,
const bn_<mode>& l,
const tensor& params_grad
)
{
update_considering_bias(learning_rate, l, params_grad, params_grad.size()/2);
return v;
}
friend void serialize(const sgd& item, std::ostream& out) friend void serialize(const sgd& item, std::ostream& out)
{ {
serialize("sgd2", out); serialize("sgd2", out);
...@@ -244,6 +255,17 @@ namespace dlib ...@@ -244,6 +255,17 @@ namespace dlib
return s; return s;
} }
template < layer_mode mode >
const tensor& operator() (
const float learning_rate,
const bn_<mode>& l,
const tensor& params_grad
)
{
update_considering_bias(learning_rate, l, params_grad, params_grad.size()/2);
return s;
}
friend void serialize(const adam& item, std::ostream& out) friend void serialize(const adam& item, std::ostream& out)
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment