Commit 61dce40e authored by Davis King

Added set_all_bn_running_stats_window_sizes() and also changed the default batch normalization running stats window size from 1000 to 100.
parent a280e48c
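
As a quick illustration of the new API (a minimal sketch, not part of this commit; the small network below is hypothetical):

    #include <dlib/dnn.h>
    using namespace dlib;

    // A toy network containing two bn_con layers.
    using net_type = loss_multiclass_log<
                     fc<10,
                     relu<bn_con<con<16,3,3,1,1,
                     relu<bn_con<con<8,3,3,1,1,
                     input<matrix<unsigned char>>>>>>>>>>;

    int main()
    {
        net_type net;

        // With this commit every bn_ layer defaults to a running stats window of 100
        // (previously 1000).  To pick a different window for all bn_ layers at once:
        set_all_bn_running_stats_window_sizes(net, 500);

        // A single layer can still be adjusted through its layer details, e.g.:
        // layer<2>(net).layer_details().set_running_stats_window_size(200);
    }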
......@@ -718,12 +718,19 @@ namespace dlib
bias_learning_rate_multiplier(1),
bias_weight_decay_multiplier(1),
eps(eps_)
{}
{
DLIB_CASSERT(window_size > 0, "The batch normalization running stats window size can't be 0.");
}
bn_() : bn_(1000) {}
bn_() : bn_(100) {}
layer_mode get_mode() const { return mode; }
unsigned long get_running_stats_window_size () const { return running_stats_window_size; }
void set_running_stats_window_size (unsigned long new_window_size )
{
DLIB_CASSERT(new_window_size > 0, "The batch normalization running stats window size can't be 0.");
running_stats_window_size = new_window_size;
}
double get_eps() const { return eps; }
double get_learning_rate_multiplier () const { return learning_rate_multiplier; }
......@@ -776,8 +783,10 @@ namespace dlib
if (sub.get_output().num_samples() > 1)
{
const double decay = 1.0 - num_updates/(num_updates+1.0);
if (num_updates < running_stats_window_size)
++num_updates;
++num_updates;
if (num_updates > running_stats_window_size)
num_updates = running_stats_window_size;
if (mode == FC_MODE)
tt::batch_normalize(eps, output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b);
else
......@@ -867,6 +876,7 @@ namespace dlib
else
out << "bn_fc ";
out << " eps="<<item.eps;
out << " running_stats_window_size="<<item.running_stats_window_size;
out << " learning_rate_mult="<<item.learning_rate_multiplier;
out << " weight_decay_mult="<<item.weight_decay_multiplier;
out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
......@@ -882,6 +892,7 @@ namespace dlib
out << "<bn_fc";
out << " eps='"<<item.eps<<"'";
out << " running_stats_window_size='"<<item.running_stats_window_size<<"'";
out << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'";
out << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'";
out << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'";
......@@ -918,6 +929,56 @@ namespace dlib
template <typename SUBNET>
using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;
// ----------------------------------------------------------------------------------------
namespace impl
{
class visitor_bn_running_stats_window_size
{
public:
visitor_bn_running_stats_window_size(unsigned long new_window_size_) : new_window_size(new_window_size_) {}
template <typename T>
void set_window_size(T&) const
{
// ignore other layer detail types
}
template < layer_mode mode >
void set_window_size(bn_<mode>& l) const
{
l.set_running_stats_window_size(new_window_size);
}
template<typename input_layer_type>
void operator()(size_t , input_layer_type& ) const
{
// ignore other layers
}
template <typename T, typename U, typename E>
void operator()(size_t , add_layer<T,U,E>& l) const
{
set_window_size(l.layer_details());
}
private:
unsigned long new_window_size;
};
}
template <typename net_type>
void set_all_bn_running_stats_window_sizes (
net_type& net,
unsigned long new_window_size
)
{
visit_layers(net, impl::visitor_bn_running_stats_window_size(new_window_size));
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
enum fc_bias_mode
......
......@@ -922,7 +922,7 @@ namespace dlib
/*!
ensures
- #get_mode() == mode
- #get_running_stats_window_size() == 1000
- #get_running_stats_window_size() == 100
- #get_learning_rate_multiplier() == 1
- #get_weight_decay_multiplier() == 0
- #get_bias_learning_rate_multiplier() == 1
......@@ -937,6 +937,7 @@ namespace dlib
/*!
requires
- eps > 0
- window_size > 0
ensures
- #get_mode() == mode
- #get_running_stats_window_size() == window_size
......@@ -985,6 +986,16 @@ namespace dlib
the running average.
!*/
void set_running_stats_window_size (
unsigned long new_window_size
);
/*!
requires
- new_window_size > 0
ensures
- #get_running_stats_window_size() == new_window_size
!*/
double get_learning_rate_multiplier(
) const;
/*!
......@@ -1078,6 +1089,23 @@ namespace dlib
template <typename SUBNET>
using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;
// ----------------------------------------------------------------------------------------
template <typename net_type>
void set_all_bn_running_stats_window_sizes (
const net_type& net,
unsigned long new_window_size
);
/*!
requires
- new_window_size > 0
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
add_tag_layer.
ensures
- Sets the get_running_stats_window_size() field of all bn_ layers in net to
new_window_size.
!*/
// ----------------------------------------------------------------------------------------
class affine_
......
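
For reference, the decay used when folding each batch into the running statistics (see the forward pass hunk above) is 1 - num_updates/(num_updates+1) = 1/(num_updates+1), and num_updates saturates at the window size.  So the running mean behaves like a plain average while num_updates is still below the window size and like an exponential moving average with factor 1/(window_size+1) once it saturates.  A standalone sketch of that behavior (illustrative code only, assumed to mirror how tt::batch_normalize() folds each batch mean into running_means; it is not part of the commit):

    #include <cstdio>

    int main()
    {
        const unsigned long window_size = 100;   // the new default window
        unsigned long num_updates = 0;
        double running_mean = 0;

        for (int i = 0; i < 300; ++i)
        {
            const double batch_mean = 5.0;       // pretend every batch has mean 5
            const double decay = 1.0 - num_updates/(num_updates+1.0);
            ++num_updates;
            if (num_updates > window_size)
                num_updates = window_size;

            // Plain average for the first window_size batches, then an exponential
            // moving average with factor 1/(window_size+1).
            running_mean = decay*batch_mean + (1-decay)*running_mean;
        }
        std::printf("running mean after 300 batches: %f\n", running_mean);
    }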