Commit 61dce40e authored by Davis King's avatar Davis King

Added set_all_bn_running_stats_window_sizes() and also changed the default

batch normalization running stats window size from 1000 to 100.
parent a280e48c
...@@ -718,12 +718,19 @@ namespace dlib ...@@ -718,12 +718,19 @@ namespace dlib
bias_learning_rate_multiplier(1), bias_learning_rate_multiplier(1),
bias_weight_decay_multiplier(1), bias_weight_decay_multiplier(1),
eps(eps_) eps(eps_)
{} {
DLIB_CASSERT(window_size > 0, "The batch normalization running stats window size can't be 0.");
}
bn_() : bn_(1000) {} bn_() : bn_(100) {}
layer_mode get_mode() const { return mode; } layer_mode get_mode() const { return mode; }
unsigned long get_running_stats_window_size () const { return running_stats_window_size; } unsigned long get_running_stats_window_size () const { return running_stats_window_size; }
void set_running_stats_window_size (unsigned long new_window_size )
{
DLIB_CASSERT(new_window_size > 0, "The batch normalization running stats window size can't be 0.");
running_stats_window_size = new_window_size;
}
double get_eps() const { return eps; } double get_eps() const { return eps; }
double get_learning_rate_multiplier () const { return learning_rate_multiplier; } double get_learning_rate_multiplier () const { return learning_rate_multiplier; }
...@@ -776,8 +783,10 @@ namespace dlib ...@@ -776,8 +783,10 @@ namespace dlib
if (sub.get_output().num_samples() > 1) if (sub.get_output().num_samples() > 1)
{ {
const double decay = 1.0 - num_updates/(num_updates+1.0); const double decay = 1.0 - num_updates/(num_updates+1.0);
if (num_updates <running_stats_window_size) ++num_updates;
++num_updates; if (num_updates > running_stats_window_size)
num_updates = running_stats_window_size;
if (mode == FC_MODE) if (mode == FC_MODE)
tt::batch_normalize(eps, output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b); tt::batch_normalize(eps, output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b);
else else
...@@ -867,6 +876,7 @@ namespace dlib ...@@ -867,6 +876,7 @@ namespace dlib
else else
out << "bn_fc "; out << "bn_fc ";
out << " eps="<<item.eps; out << " eps="<<item.eps;
out << " running_stats_window_size="<<item.running_stats_window_size;
out << " learning_rate_mult="<<item.learning_rate_multiplier; out << " learning_rate_mult="<<item.learning_rate_multiplier;
out << " weight_decay_mult="<<item.weight_decay_multiplier; out << " weight_decay_mult="<<item.weight_decay_multiplier;
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier; out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
...@@ -882,6 +892,7 @@ namespace dlib ...@@ -882,6 +892,7 @@ namespace dlib
out << "<bn_fc"; out << "<bn_fc";
out << " eps='"<<item.eps<<"'"; out << " eps='"<<item.eps<<"'";
out << " running_stats_window_size='"<<item.running_stats_window_size<<"'";
out << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"; out << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'";
out << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"; out << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'";
out << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"; out << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'";
...@@ -918,6 +929,56 @@ namespace dlib ...@@ -918,6 +929,56 @@ namespace dlib
template <typename SUBNET> template <typename SUBNET>
using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>; using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;
// ----------------------------------------------------------------------------------------
namespace impl
{
class visitor_bn_running_stats_window_size
{
public:
visitor_bn_running_stats_window_size(unsigned long new_window_size_) : new_window_size(new_window_size_) {}
template <typename T>
void set_window_size(T&) const
{
// ignore other layer detail types
}
template < layer_mode mode >
void set_window_size(bn_<mode>& l) const
{
l.set_running_stats_window_size(new_window_size);
}
template<typename input_layer_type>
void operator()(size_t , input_layer_type& ) const
{
// ignore other layers
}
template <typename T, typename U, typename E>
void operator()(size_t , add_layer<T,U,E>& l) const
{
set_window_size(l.layer_details());
}
private:
unsigned long new_window_size;
};
}
template <typename net_type>
void set_all_bn_running_stats_window_sizes (
net_type& net,
unsigned long new_window_size
)
{
visit_layers(net, impl::visitor_bn_running_stats_window_size(new_window_size));
}
// ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
enum fc_bias_mode enum fc_bias_mode
......
...@@ -922,7 +922,7 @@ namespace dlib ...@@ -922,7 +922,7 @@ namespace dlib
/*! /*!
ensures ensures
- #get_mode() == mode - #get_mode() == mode
- #get_running_stats_window_size() == 1000 - #get_running_stats_window_size() == 100
- #get_learning_rate_multiplier() == 1 - #get_learning_rate_multiplier() == 1
- #get_weight_decay_multiplier() == 0 - #get_weight_decay_multiplier() == 0
- #get_bias_learning_rate_multiplier() == 1 - #get_bias_learning_rate_multiplier() == 1
...@@ -937,6 +937,7 @@ namespace dlib ...@@ -937,6 +937,7 @@ namespace dlib
/*! /*!
requires requires
- eps > 0 - eps > 0
- window_size > 0
ensures ensures
- #get_mode() == mode - #get_mode() == mode
- #get_running_stats_window_size() == window_size - #get_running_stats_window_size() == window_size
...@@ -985,6 +986,16 @@ namespace dlib ...@@ -985,6 +986,16 @@ namespace dlib
the running average. the running average.
!*/ !*/
void set_running_stats_window_size (
unsigned long new_window_size
);
/*!
requires
- new_window_size > 0
ensures
- #get_running_stats_window_size() == new_window_size
!*/
double get_learning_rate_multiplier( double get_learning_rate_multiplier(
) const; ) const;
/*! /*!
...@@ -1078,6 +1089,23 @@ namespace dlib ...@@ -1078,6 +1089,23 @@ namespace dlib
template <typename SUBNET> template <typename SUBNET>
using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>; using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;
// ----------------------------------------------------------------------------------------
template <typename net_type>
void set_all_bn_running_stats_window_sizes (
const net_type& net,
unsigned long new_window_size
);
/*!
requires
- new_window_size > 0
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
add_tag_layer.
ensures
- Sets the get_running_stats_window_size() field of all bn_ layers in net to
new_window_size.
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
class affine_ class affine_
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment