Commit b92b226c authored by Davis King
parent 40f04beb

Added learning rate and weight decay multipliers to the con_, fc_, and bn_
layers.  Updated the solvers to support this.
@@ -488,6 +488,8 @@ namespace dlib
 // -----------------------------------------------------------------------------------
     void compute_adam_update (
+        size_t begin,
+        size_t end,
         tensor& s,
         tensor& m,
         tensor& v,
@@ -504,6 +506,7 @@ namespace dlib
             s.size() == v.size() &&
             s.size() == params.size() &&
             s.size() == params_grad.size(),"");
+        DLIB_CASSERT(begin <= end && end <= params.size(),"");
         const float eps = 1e-8;
         const float alpha = learning_rate*std::sqrt(1-std::pow(momentum2,t))/(1-std::pow(momentum1, t));
@@ -516,7 +519,7 @@ namespace dlib
         auto ps = s.host_write_only();
         auto pparams = params.host();
         auto ppgrad = params_grad.host();
-        for (size_t i = 0; i < params.size(); ++i)
+        for (size_t i = begin; i < end; ++i)
         {
             float g = weight_decay*pparams[i] + ppgrad[i];
             pm[i] = momentum1*pm[i] + (1-momentum1)*g;
...
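For reference, a minimal standalone sketch of the update the new [begin,end) range performs, with plain float arrays standing in for dlib tensors (it mirrors the loop and the formulas shown in the hunk above; nothing here is part of the commit itself):

```cpp
#include <cmath>
#include <cstddef>

// Adam step with L2 weight decay over the half open range [begin,end).
void adam_update_range(std::size_t begin, std::size_t end,
                       float* s, float* m, float* v,
                       float t, float learning_rate, float weight_decay,
                       float momentum1, float momentum2,
                       const float* params, const float* params_grad)
{
    const float eps = 1e-8;
    // bias-corrected step size, exactly as computed in the code above
    const float alpha = learning_rate*std::sqrt(1-std::pow(momentum2,t))/(1-std::pow(momentum1,t));
    for (std::size_t i = begin; i < end; ++i)
    {
        const float g = weight_decay*params[i] + params_grad[i];
        m[i] = momentum1*m[i] + (1-momentum1)*g;            // first moment
        v[i] = momentum2*v[i] + (1-momentum2)*g*g;          // second moment
        s[i] = -alpha*m[i]/(std::sqrt(v[i]) + eps);         // update to add to params
    }
}
```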
@@ -114,6 +114,8 @@ namespace dlib
 // -----------------------------------------------------------------------------------
     void compute_adam_update (
+        size_t begin,
+        size_t end,
         tensor& s,
         tensor& m,
         tensor& v,
...
@@ -583,7 +583,8 @@ namespace dlib
 // ----------------------------------------------------------------------------------------
     __global__ void _cuda_compute_adam_update(
-        size_t n,
+        size_t begin,
+        size_t end,
         float* s,
         float* m,
         float* v,
@@ -600,7 +601,7 @@ namespace dlib
         // m = momentum1*m + (1-momentum1) * (weight_decay*params + params_grad);
         // v = momentum2*v + (1-momentum2)*squared(weight_decay*params + params_grad);
         // s = -alpha*m/(sqrt(v) + eps);
-        for (auto i : grid_stride_range(0, n))
+        for (auto i : grid_stride_range(begin, end))
         {
             float g = (weight_decay*params[i] + params_grad[i]);
             m[i] = momentum1*m[i] + (1-momentum1)*g;
@@ -610,6 +611,8 @@ namespace dlib
     }
     void compute_adam_update (
+        size_t begin,
+        size_t end,
         tensor& s,
         tensor& m,
         tensor& v,
@@ -626,10 +629,11 @@ namespace dlib
             s.size() == v.size() &&
             s.size() == params.size() &&
             s.size() == params_grad.size(),"");
+        DLIB_CASSERT(begin <= end && end <= params.size(),"");
         const float alpha = learning_rate*std::sqrt(1-std::pow(momentum2,t))/(1-std::pow(momentum1, t));
-        launch_kernel(_cuda_compute_adam_update,max_jobs(s.size()),
-                      s.size(), s.device(), m.device(), v.device(), alpha, weight_decay,
+        launch_kernel(_cuda_compute_adam_update,max_jobs(end-begin),
+                      begin, end, s.device(), m.device(), v.device(), alpha, weight_decay,
                       momentum1, momentum2, params.device(), params_grad.device());
     }
...
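For reference, grid_stride_range(begin, end) is assumed to expand to the usual CUDA grid-stride pattern, which is why sizing the launch with max_jobs(end-begin) covers exactly the elements being updated. A minimal sketch of that pattern (illustration only, not dlib's implementation):

```cuda
// Each thread starts at begin plus its global id and strides by the total
// number of threads in the grid until it passes end.
__global__ void example_range_kernel(size_t begin, size_t end, float* data)
{
    for (size_t i = begin + blockIdx.x*(size_t)blockDim.x + threadIdx.x;
         i < end;
         i += (size_t)gridDim.x*blockDim.x)
    {
        data[i] *= 2;  // placeholder work on the [begin,end) range
    }
}
```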
@@ -205,6 +205,8 @@ namespace dlib
 // ----------------------------------------------------------------------------------------
     void compute_adam_update (
+        size_t begin,
+        size_t end,
         tensor& s,
         tensor& m,
         tensor& v,
...
@@ -42,6 +42,10 @@ namespace dlib
         con_(
         ) :
+            learning_rate_multiplier(1),
+            weight_decay_multiplier(1),
+            bias_learning_rate_multiplier(1),
+            bias_weight_decay_multiplier(0),
             padding_y_(_padding_y),
             padding_x_(_padding_x)
         {}
@@ -54,12 +58,27 @@ namespace dlib
         long padding_y() const { return padding_y_; }
         long padding_x() const { return padding_x_; }
+        double get_learning_rate_multiplier () const { return learning_rate_multiplier; }
+        double get_weight_decay_multiplier () const { return weight_decay_multiplier; }
+        void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
+        void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; }
+        double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; }
+        double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
+        void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
+        void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
         con_ (
             const con_& item
         ) :
             params(item.params),
             filters(item.filters),
             biases(item.biases),
+            learning_rate_multiplier(item.learning_rate_multiplier),
+            weight_decay_multiplier(item.weight_decay_multiplier),
+            bias_learning_rate_multiplier(item.bias_learning_rate_multiplier),
+            bias_weight_decay_multiplier(item.bias_weight_decay_multiplier),
             padding_y_(item.padding_y_),
             padding_x_(item.padding_x_)
         {
@@ -81,6 +100,10 @@ namespace dlib
             biases = item.biases;
             padding_y_ = item.padding_y_;
             padding_x_ = item.padding_x_;
+            learning_rate_multiplier = item.learning_rate_multiplier;
+            weight_decay_multiplier = item.weight_decay_multiplier;
+            bias_learning_rate_multiplier = item.bias_learning_rate_multiplier;
+            bias_weight_decay_multiplier = item.bias_weight_decay_multiplier;
             return *this;
         }
@@ -121,18 +144,22 @@ namespace dlib
         void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
         {
             conv.get_gradient_for_data (gradient_input, filters(params,0), sub.get_gradient_input());
+            // no point computing the parameter gradients if they won't be used.
+            if (learning_rate_multiplier != 0)
+            {
                 auto filt = filters(params_grad,0);
                 conv.get_gradient_for_filters (gradient_input, sub.get_output(), filt);
                 auto b = biases(params_grad, filters.size());
                 tt::assign_conv_bias_gradient(b, gradient_input);
             }
+        }
         const tensor& get_layer_params() const { return params; }
         tensor& get_layer_params() { return params; }
         friend void serialize(const con_& item, std::ostream& out)
         {
-            serialize("con_2", out);
+            serialize("con_3", out);
             serialize(item.params, out);
             serialize(_num_filters, out);
             serialize(_nr, out);
@@ -143,6 +170,10 @@ namespace dlib
             serialize(item.padding_x_, out);
             serialize(item.filters, out);
             serialize(item.biases, out);
+            serialize(item.learning_rate_multiplier, out);
+            serialize(item.weight_decay_multiplier, out);
+            serialize(item.bias_learning_rate_multiplier, out);
+            serialize(item.bias_weight_decay_multiplier, out);
         }
         friend void deserialize(con_& item, std::istream& in)
@@ -167,7 +198,7 @@ namespace dlib
                 item.padding_y_ = nr/2;
                 item.padding_x_ = nc/2;
             }
-            else if (version == "con_2")
+            else if (version == "con_2" || version == "con_3")
             {
                 deserialize(item.params, in);
                 deserialize(num_filters, in);
@@ -180,6 +211,23 @@ namespace dlib
                 deserialize(item.filters, in);
                 deserialize(item.biases, in);
+                if (version == "con_3")
+                {
+                    deserialize(item.learning_rate_multiplier, in);
+                    deserialize(item.weight_decay_multiplier, in);
+                    deserialize(item.bias_learning_rate_multiplier, in);
+                    deserialize(item.bias_weight_decay_multiplier, in);
+                }
+                else
+                {
+                    // Previous versions didn't have these parameters, so they were
+                    // implicitly 1.
+                    item.learning_rate_multiplier = 1;
+                    item.weight_decay_multiplier = 1;
+                    item.bias_learning_rate_multiplier = 1;
+                    item.bias_weight_decay_multiplier = 1;
+                }
                 if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::con_");
                 if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::con_");
             }
@@ -207,6 +255,10 @@ namespace dlib
                 << ", padding_y="<<item.padding_y_
                 << ", padding_x="<<item.padding_x_
                 << ")";
+            out << " learning_rate_mult="<<item.learning_rate_multiplier;
+            out << " weight_decay_mult="<<item.weight_decay_multiplier;
+            out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
+            out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
             return out;
         }
@@ -217,6 +269,10 @@ namespace dlib
         alias_tensor filters, biases;
         tt::tensor_conv conv;
+        double learning_rate_multiplier;
+        double weight_decay_multiplier;
+        double bias_learning_rate_multiplier;
+        double bias_weight_decay_multiplier;
         // These are here only because older versions of con (which you might encounter
         // serialized to disk) used different padding settings.
@@ -600,15 +656,24 @@ namespace dlib
     class bn_
     {
     public:
-        bn_() : num_updates(0), running_stats_window_size(1000)
+        explicit bn_(unsigned long window_size) :
+            num_updates(0),
+            running_stats_window_size(window_size),
+            learning_rate_multiplier(1),
+            weight_decay_multiplier(0)
         {}
-        explicit bn_(unsigned long window_size) : num_updates(0), running_stats_window_size(window_size)
-        {}
+        bn_() : bn_(1000) {}
         layer_mode get_mode() const { return mode; }
         unsigned long get_running_stats_window_size () const { return running_stats_window_size; }
+        double get_learning_rate_multiplier () const { return learning_rate_multiplier; }
+        double get_weight_decay_multiplier () const { return weight_decay_multiplier; }
+        void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
+        void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; }
         template <typename SUBNET>
         void setup (const SUBNET& sub)
         {
@@ -679,9 +744,9 @@ namespace dlib
         friend void serialize(const bn_& item, std::ostream& out)
         {
             if (mode == CONV_MODE)
-                serialize("bn_con", out);
+                serialize("bn_con2", out);
             else // if FC_MODE
-                serialize("bn_fc", out);
+                serialize("bn_fc2", out);
             serialize(item.params, out);
             serialize(item.gamma, out);
             serialize(item.beta, out);
@@ -691,6 +756,8 @@ namespace dlib
             serialize(item.running_variances, out);
             serialize(item.num_updates, out);
             serialize(item.running_stats_window_size, out);
+            serialize(item.learning_rate_multiplier, out);
+            serialize(item.weight_decay_multiplier, out);
         }
         friend void deserialize(bn_& item, std::istream& in)
@@ -701,12 +768,12 @@ namespace dlib
             {
                 if (mode == CONV_MODE)
                 {
-                    if (version != "bn_con")
+                    if (version != "bn_con" && version != "bn_con2")
                         throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
                 }
                 else // must be in FC_MODE
                 {
-                    if (version != "bn_fc")
+                    if (version != "bn_fc" && version != "bn_fc2")
                         throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
                 }
             }
@@ -733,14 +800,28 @@ namespace dlib
                 // format saved the inverse standard deviations instead of variances.
                 item.running_variances = 1.0f/squared(mat(item.running_variances)) - tt::BATCH_NORM_EPS;
             }
+            else if (version == "bn_con2" || version == "bn_fc2")
+            {
+                deserialize(item.learning_rate_multiplier, in);
+                deserialize(item.weight_decay_multiplier, in);
+            }
+            else
+            {
+                // Previous versions didn't have these parameters, so they were
+                // implicitly 1.
+                item.learning_rate_multiplier = 1;
+                item.weight_decay_multiplier = 1;
+            }
         }
         friend std::ostream& operator<<(std::ostream& out, const bn_& item)
         {
             if (mode == CONV_MODE)
-                out << "bn_con";
+                out << "bn_con ";
             else
-                out << "bn_fc";
+                out << "bn_fc ";
+            out << " learning_rate_mult="<<item.learning_rate_multiplier;
+            out << " weight_decay_mult="<<item.weight_decay_multiplier;
             return out;
         }
@@ -754,6 +835,8 @@ namespace dlib
         resizable_tensor invstds, running_variances;
         unsigned long num_updates;
         unsigned long running_stats_window_size;
+        double learning_rate_multiplier;
+        double weight_decay_multiplier;
     };
     template <typename SUBNET>
@@ -784,11 +867,24 @@ namespace dlib
         static_assert(num_outputs_ > 0, "The number of outputs from a fc_ layer must be > 0");
     public:
-        fc_() : num_outputs(num_outputs_), num_inputs(0)
-        {
-        }
-        fc_(num_fc_outputs o) : num_outputs(o.num_outputs), num_inputs(0) {}
+        fc_(num_fc_outputs o) : num_outputs(o.num_outputs), num_inputs(0),
+            learning_rate_multiplier(1),
+            weight_decay_multiplier(1),
+            bias_learning_rate_multiplier(1),
+            bias_weight_decay_multiplier(0)
+        {}
+        fc_() : fc_(num_fc_outputs(num_outputs_)) {}
+        double get_learning_rate_multiplier () const { return learning_rate_multiplier; }
+        double get_weight_decay_multiplier () const { return weight_decay_multiplier; }
+        void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
+        void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; }
+        double get_bias_learning_rate_multiplier () const { return bias_learning_rate_multiplier; }
+        double get_bias_weight_decay_multiplier () const { return bias_weight_decay_multiplier; }
+        void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
+        void set_bias_weight_decay_multiplier(double val) { bias_weight_decay_multiplier = val; }
         unsigned long get_num_outputs (
         ) const { return num_outputs; }
@@ -834,6 +930,9 @@ namespace dlib
         template <typename SUBNET>
         void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
+        {
+            // no point computing the parameter gradients if they won't be used.
+            if (learning_rate_multiplier != 0)
             {
                 // compute the gradient of the weight parameters.
                 auto pw = weights(params_grad, 0);
@@ -845,6 +944,7 @@ namespace dlib
                 auto pb = biases(params_grad, weights.size());
                 tt::assign_bias_gradient(pb, gradient_input);
             }
+            }
             // compute the gradient for the data
             auto w = weights(params, 0);
@@ -856,20 +956,24 @@ namespace dlib
         friend void serialize(const fc_& item, std::ostream& out)
         {
-            serialize("fc_", out);
+            serialize("fc_2", out);
             serialize(item.num_outputs, out);
             serialize(item.num_inputs, out);
             serialize(item.params, out);
             serialize(item.weights, out);
             serialize(item.biases, out);
             serialize((int)bias_mode, out);
+            serialize(item.learning_rate_multiplier, out);
+            serialize(item.weight_decay_multiplier, out);
+            serialize(item.bias_learning_rate_multiplier, out);
+            serialize(item.bias_weight_decay_multiplier, out);
         }
         friend void deserialize(fc_& item, std::istream& in)
         {
             std::string version;
             deserialize(version, in);
-            if (version != "fc_")
+            if (version != "fc_" && version != "fc_2")
                 throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::fc_.");
             deserialize(item.num_outputs, in);
@@ -880,6 +984,22 @@ namespace dlib
             int bmode = 0;
             deserialize(bmode, in);
             if (bias_mode != (fc_bias_mode)bmode) throw serialization_error("Wrong fc_bias_mode found while deserializing dlib::fc_");
+            if (version == "fc_2")
+            {
+                deserialize(item.learning_rate_multiplier, in);
+                deserialize(item.weight_decay_multiplier, in);
+                deserialize(item.bias_learning_rate_multiplier, in);
+                deserialize(item.bias_weight_decay_multiplier, in);
+            }
+            else
+            {
+                // Previous versions didn't have these parameters, so they were
+                // implicitly 1.
+                item.learning_rate_multiplier = 1;
+                item.weight_decay_multiplier = 1;
+                item.bias_learning_rate_multiplier = 1;
+                item.bias_weight_decay_multiplier = 1;
+            }
         }
         friend std::ostream& operator<<(std::ostream& out, const fc_& item)
@@ -889,12 +1009,18 @@ namespace dlib
                 out << "fc\t ("
                     << "num_outputs="<<item.num_outputs
                     << ")";
+                out << " learning_rate_mult="<<item.learning_rate_multiplier;
+                out << " weight_decay_mult="<<item.weight_decay_multiplier;
+                out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
+                out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
             }
             else
             {
                 out << "fc_no_bias ("
                     << "num_outputs="<<item.num_outputs
                     << ")";
+                out << " learning_rate_mult="<<item.learning_rate_multiplier;
+                out << " weight_decay_mult="<<item.weight_decay_multiplier;
             }
             return out;
         }
@@ -905,6 +1031,10 @@ namespace dlib
         unsigned long num_inputs;
         resizable_tensor params;
         alias_tensor weights, biases;
+        double learning_rate_multiplier;
+        double weight_decay_multiplier;
+        double bias_learning_rate_multiplier;
+        double bias_weight_decay_multiplier;
     };
     template <
@@ -1223,7 +1353,7 @@ namespace dlib
         {
             std::string version;
             deserialize(version, in);
-            if (version == "bn_con")
+            if (version == "bn_con" || version == "bn_con2")
             {
                 // Since we can build an affine_ from a bn_ we check if that's what is in
                 // the stream and if so then just convert it right here.
@@ -1233,7 +1363,7 @@ namespace dlib
                 item = temp;
                 return;
             }
-            else if (version == "bn_fc")
+            else if (version == "bn_fc" || version == "bn_fc2")
            {
                 // Since we can build an affine_ from a bn_ we check if that's what is in
                 // the stream and if so then just convert it right here.
...
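The point of the new setters is that the multipliers can be tuned per layer before training, for example to slow down or freeze part of a network. A hypothetical usage sketch follows; the network structure and layer indices are made up for illustration and are not part of this commit:

```cpp
#include <dlib/dnn.h>
using namespace dlib;

// A small made-up network: fc on top of relu on top of a conv layer.
using net_type = loss_multiclass_log<fc<10, relu<con<6,5,5,1,1,
                     input<matrix<unsigned char>>>>>>;

int main()
{
    net_type net;

    // layer<3>(net) is assumed to be the con layer in this particular network.
    // Give its filters a 10x smaller learning rate and freeze its biases.
    layer<3>(net).layer_details().set_learning_rate_multiplier(0.1);
    layer<3>(net).layer_details().set_bias_learning_rate_multiplier(0);

    // layer<1>(net) is assumed to be the fc layer; turn off weight decay on it.
    layer<1>(net).layer_details().set_weight_decay_multiplier(0);
}
```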
@@ -123,6 +123,16 @@ namespace dlib
                 allow dlib to make some layers execute in-place and therefore run a
                 little faster and use less memory. Do not implement forward() and
                 backward().
+                It should also be noted that layers may define additional layer specific
+                fields and the solvers can use these fields as they see fit. For example,
+                some layers define get_learning_rate_multiplier() and
+                get_weight_decay_multiplier() methods. The solvers that come with dlib
+                look at these methods, if they exist, and adjust the learning rate or
+                weight decay for that layer according to the multiplier. Therefore, you
+                can add these methods to your layer types if you want, or even define new
+                fields and new solvers that use those fields in some way.
         !*/
     public:
@@ -367,6 +377,10 @@ namespace dlib
             ensures
                 - #get_num_outputs() == num_outputs
                 - #get_bias_mode() == bias_mode
+                - #get_learning_rate_multiplier() == 1
+                - #get_weight_decay_multiplier() == 1
+                - #get_bias_learning_rate_multiplier() == 1
+                - #get_bias_weight_decay_multiplier() == 0
         !*/
         unsigned long get_num_outputs (
@@ -389,6 +403,82 @@ namespace dlib
               is added to each of the outputs of this layer.
         !*/
+        double get_learning_rate_multiplier(
+        ) const;
+        /*!
+            ensures
+                - returns a multiplier number. The interpretation is that this object is
+                  requesting that the learning rate used to optimize its parameters be
+                  multiplied by get_learning_rate_multiplier().
+        !*/
+        double get_weight_decay_multiplier(
+        ) const;
+        /*!
+            ensures
+                - returns a multiplier number. The interpretation is that this object is
+                  requesting that the weight decay used to optimize its parameters be
+                  multiplied by get_weight_decay_multiplier().
+        !*/
+        void set_learning_rate_multiplier(
+            double val
+        );
+        /*!
+            requires
+                - val >= 0
+            ensures
+                - #get_learning_rate_multiplier() == val
+        !*/
+        void set_weight_decay_multiplier(
+            double val
+        );
+        /*!
+            requires
+                - val >= 0
+            ensures
+                - #get_weight_decay_multiplier() == val
+        !*/
+        double get_bias_learning_rate_multiplier(
+        ) const;
+        /*!
+            ensures
+                - returns a multiplier number. The interpretation is that this object is
+                  requesting that the learning rate used to optimize its bias parameters be
+                  multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier().
+        !*/
+        double get_bias_weight_decay_multiplier(
+        ) const;
+        /*!
+            ensures
+                - returns a multiplier number. The interpretation is that this object is
+                  requesting that the weight decay used to optimize its bias parameters be
+                  multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier().
+        !*/
+        void set_bias_learning_rate_multiplier(
+            double val
+        );
+        /*!
+            requires
+                - val >= 0
+            ensures
+                - #get_bias_learning_rate_multiplier() == val
+        !*/
+        void set_bias_weight_decay_multiplier(
+            double val
+        );
+        /*!
+            requires
+                - val >= 0
+            ensures
+                - #get_bias_weight_decay_multiplier() == val
+        !*/
         template <typename SUBNET> void setup (const SUBNET& sub);
         template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
         template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
@@ -458,6 +548,10 @@ namespace dlib
                 - #stride_x() == _stride_x
                 - #padding_y() == _padding_y
                 - #padding_x() == _padding_x
+                - #get_learning_rate_multiplier() == 1
+                - #get_weight_decay_multiplier() == 1
+                - #get_bias_learning_rate_multiplier() == 1
+                - #get_bias_weight_decay_multiplier() == 0
         !*/
         long num_filters(
@@ -517,6 +611,82 @@ namespace dlib
               sides of the image.
         !*/
+        double get_learning_rate_multiplier(
+        ) const;
+        /*!
+            ensures
+                - returns a multiplier number. The interpretation is that this object is
+                  requesting that the learning rate used to optimize its parameters be
+                  multiplied by get_learning_rate_multiplier().
+        !*/
+        double get_weight_decay_multiplier(
+        ) const;
+        /*!
+            ensures
+                - returns a multiplier number. The interpretation is that this object is
+                  requesting that the weight decay used to optimize its parameters be
+                  multiplied by get_weight_decay_multiplier().
+        !*/
+        void set_learning_rate_multiplier(
+            double val
+        );
+        /*!
+            requires
+                - val >= 0
+            ensures
+                - #get_learning_rate_multiplier() == val
+        !*/
+        void set_weight_decay_multiplier(
+            double val
+        );
+        /*!
+            requires
+                - val >= 0
+            ensures
+                - #get_weight_decay_multiplier() == val
+        !*/
+        double get_bias_learning_rate_multiplier(
+        ) const;
+        /*!
+            ensures
+                - returns a multiplier number. The interpretation is that this object is
+                  requesting that the learning rate used to optimize its bias parameters be
+                  multiplied by get_learning_rate_multiplier()*get_bias_learning_rate_multiplier().
+        !*/
+        double get_bias_weight_decay_multiplier(
+        ) const;
+        /*!
+            ensures
+                - returns a multiplier number. The interpretation is that this object is
+                  requesting that the weight decay used to optimize its bias parameters be
+                  multiplied by get_weight_decay_multiplier()*get_bias_weight_decay_multiplier().
+        !*/
+        void set_bias_learning_rate_multiplier(
+            double val
+        );
+        /*!
+            requires
+                - val >= 0
+            ensures
+                - #get_bias_learning_rate_multiplier() == val
+        !*/
+        void set_bias_weight_decay_multiplier(
+            double val
+        );
+        /*!
+            requires
+                - val >= 0
+            ensures
+                - #get_bias_weight_decay_multiplier() == val
+        !*/
         template <typename SUBNET> void setup (const SUBNET& sub);
         template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
         template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
@@ -684,7 +854,9 @@ namespace dlib
         /*!
             ensures
                 - #get_mode() == mode
-                - get_running_stats_window_size() == 1000
+                - #get_running_stats_window_size() == 1000
+                - #get_learning_rate_multiplier() == 1
+                - #get_weight_decay_multiplier() == 0
         !*/
         explicit bn_(
@@ -693,7 +865,9 @@ namespace dlib
         /*!
             ensures
                 - #get_mode() == mode
-                - get_running_stats_window_size() == window_size
+                - #get_running_stats_window_size() == window_size
+                - #get_learning_rate_multiplier() == 1
+                - #get_weight_decay_multiplier() == 0
         !*/
         layer_mode get_mode(
@@ -725,6 +899,44 @@ namespace dlib
               the running average.
         !*/
+        double get_learning_rate_multiplier(
+        ) const;
+        /*!
+            ensures
+                - returns a multiplier number. The interpretation is that this object is
+                  requesting that the learning rate used to optimize its parameters be
+                  multiplied by get_learning_rate_multiplier().
+        !*/
+        double get_weight_decay_multiplier(
+        ) const;
+        /*!
+            ensures
+                - returns a multiplier number. The interpretation is that this object is
+                  requesting that the weight decay used to optimize its parameters be
+                  multiplied by get_weight_decay_multiplier().
+        !*/
+        void set_learning_rate_multiplier(
+            double val
+        );
+        /*!
+            requires
+                - val >= 0
+            ensures
+                - #get_learning_rate_multiplier() == val
+        !*/
+        void set_weight_decay_multiplier(
+            double val
+        );
+        /*!
+            requires
+                - val >= 0
+            ensures
+                - #get_weight_decay_multiplier() == val
+        !*/
         template <typename SUBNET> void setup (const SUBNET& sub);
         template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
         template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
...
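Following the note above that the stock solvers use these methods if a layer defines them, here is a sketch of the members a user-defined layer could add so its learning rate and weight decay become tunable. Only the multiplier-related parts are shown; a real layer still needs the rest of the EXAMPLE_COMPUTATIONAL_LAYER_ interface (setup(), forward(), backward(), parameters, serialization, and so on):

```cpp
// Hypothetical custom layer fragment, not part of this commit.
class my_layer_
{
public:
    my_layer_() : learning_rate_multiplier(1), weight_decay_multiplier(1) {}

    // The solvers that ship with dlib look for exactly these method names.
    double get_learning_rate_multiplier() const { return learning_rate_multiplier; }
    double get_weight_decay_multiplier() const { return weight_decay_multiplier; }
    void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
    void set_weight_decay_multiplier(double val) { weight_decay_multiplier = val; }

    // ... the usual computational layer interface goes here ...

private:
    double learning_rate_multiplier;
    double weight_decay_multiplier;
};
```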
@@ -6,6 +6,7 @@
 #include "solvers_abstract.h"
 #include "tensor.h"
 #include <iostream>
+#include "layers.h"
 namespace dlib
 {
@@ -49,10 +50,42 @@ namespace dlib
                 v = 0;
             }
-            //perform: v = momentum*mat(v) - weight_decay*learning_rate*mat(params) - learning_rate*mat(params_grad);
-            tt::affine_transform(v, v, params, params_grad,
-                                momentum, -weight_decay*learning_rate, -learning_rate, 0);
+            const double lr = learning_rate*get_learning_rate_multiplier(l);
+            const double wd = weight_decay*get_weight_decay_multiplier(l);
+            //perform: v = momentum*mat(v) - wd*lr*mat(params) - lr*mat(params_grad);
+            tt::affine_transform(v, v, params, params_grad, momentum, -wd*lr, -lr);
+            return v;
+        }
+        template <unsigned long N>
+        const tensor& operator() (
+            const float learning_rate,
+            const fc_<N,FC_HAS_BIAS>& l,
+            const tensor& params_grad
+        )
+        {
+            update_considering_bias(learning_rate, l, params_grad, l.get_num_outputs());
+            return v;
+        }
+        template <
+            long _num_filters,
+            long _nr,
+            long _nc,
+            int _stride_y,
+            int _stride_x,
+            int _padding_y,
+            int _padding_x
+            >
+        const tensor& operator() (
+            const float learning_rate,
+            const con_<_num_filters,_nr,_nc,_stride_y,_stride_x,_padding_y,_padding_x>& l,
+            const tensor& params_grad
+        )
+        {
+            update_considering_bias(learning_rate, l, params_grad, l.num_filters());
             return v;
         }
@@ -76,9 +109,49 @@ namespace dlib
         }
     private:
+        template <typename layer_type>
+        void update_considering_bias(
+            const float learning_rate,
+            const layer_type& l,
+            const tensor& params_grad,
+            unsigned long bias_offset
+        )
+        {
+            const tensor& params = l.get_layer_params();
+            DLIB_CASSERT(params.size() != 0,"");
+            if (v.size() == 0)
+            {
+                v.copy_size(params_grad);
+                v = 0;
+            }
+            double lr = learning_rate*get_learning_rate_multiplier(l);
+            double wd = weight_decay*get_weight_decay_multiplier(l);
+            //perform: v = momentum*mat(v) - wd*lr*mat(params) - lr*mat(params_grad);
+            if (l.get_bias_learning_rate_multiplier() == 1 && l.get_bias_weight_decay_multiplier() == 1)
+            {
+                tt::affine_transform(v, v, params, params_grad, momentum, -wd*lr, -lr);
+            }
+            else
+            {
+                tt::affine_transform_range(0, bias_offset, v, v, params, params_grad, momentum, -wd*lr, -lr);
+                // now update the biases but apply their multipliers
+                lr *= l.get_bias_learning_rate_multiplier();
+                wd *= l.get_bias_weight_decay_multiplier();
+                tt::affine_transform_range(bias_offset, v.size(), v, v, params, params_grad, momentum, -wd*lr, -lr);
+            }
+        }
         resizable_tensor v;
         float weight_decay;
         float momentum;
     };
 // ----------------------------------------------------------------------------------------
@@ -132,11 +205,46 @@ namespace dlib
             ++t;
-            tt::compute_adam_update(s, m, v, t, learning_rate, weight_decay, momentum1, momentum2, params, params_grad);
+            tt::compute_adam_update(0, params.size(), s, m, v, t,
+                learning_rate*get_learning_rate_multiplier(l),
+                weight_decay*get_weight_decay_multiplier(l),
+                momentum1, momentum2, params, params_grad);
             return s;
         }
+        template <unsigned long N>
+        const tensor& operator() (
+            const float learning_rate,
+            const fc_<N,FC_HAS_BIAS>& l,
+            const tensor& params_grad
+        )
+        {
+            update_considering_bias(learning_rate, l, params_grad, l.get_num_outputs());
+            return s;
+        }
+        template <
+            long _num_filters,
+            long _nr,
+            long _nc,
+            int _stride_y,
+            int _stride_x,
+            int _padding_y,
+            int _padding_x
+            >
+        const tensor& operator() (
+            const float learning_rate,
+            const con_<_num_filters,_nr,_nc,_stride_y,_stride_x,_padding_y,_padding_x>& l,
+            const tensor& params_grad
+        )
+        {
+            update_considering_bias(learning_rate, l, params_grad, l.num_filters());
+            return s;
+        }
         friend void serialize(const adam& item, std::ostream& out)
         {
             serialize("adam2", out);
@@ -165,6 +273,49 @@ namespace dlib
         }
     private:
+        template <typename layer_type>
+        void update_considering_bias(
+            const float learning_rate,
+            const layer_type& l,
+            const tensor& params_grad,
+            unsigned long bias_offset
+        )
+        {
+            const tensor& params = l.get_layer_params();
+            DLIB_CASSERT(params.size() != 0,"");
+            if (v.size() == 0)
+            {
+                m.copy_size(params_grad);
+                m = 0;
+                v.copy_size(params_grad);
+                v = 0;
+                s.copy_size(params_grad);
+            }
+            ++t;
+            if (l.get_bias_learning_rate_multiplier() == 1 && l.get_bias_weight_decay_multiplier() == 1)
+            {
+                tt::compute_adam_update(0, params.size(), s, m, v, t,
+                    learning_rate*get_learning_rate_multiplier(l),
+                    weight_decay*get_weight_decay_multiplier(l),
+                    momentum1, momentum2, params, params_grad);
+            }
+            else
+            {
+                tt::compute_adam_update(0, bias_offset, s, m, v, t,
+                    learning_rate*get_learning_rate_multiplier(l),
+                    weight_decay*get_weight_decay_multiplier(l),
+                    momentum1, momentum2, params, params_grad);
+                tt::compute_adam_update(bias_offset, params.size(), s, m, v, t,
+                    learning_rate*get_learning_rate_multiplier(l)*l.get_bias_learning_rate_multiplier(),
+                    weight_decay*get_weight_decay_multiplier(l)*l.get_bias_weight_decay_multiplier(),
+                    momentum1, momentum2, params, params_grad);
+            }
+        }
         resizable_tensor m;
         resizable_tensor v;
         resizable_tensor s;
...
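The solvers obtain the per-layer values through free functions such as get_learning_rate_multiplier(l) that fall back to 1 when a layer does not define the method. A sketch of how such a detection helper can be written in C++11 follows; it illustrates the technique and is not necessarily dlib's exact implementation:

```cpp
#include <type_traits>
#include <utility>

// Trait: does T have a member get_learning_rate_multiplier()?
template <typename T, typename = void>
struct has_lr_multiplier : std::false_type {};

template <typename T>
struct has_lr_multiplier<T, decltype(void(std::declval<const T&>().get_learning_rate_multiplier()))>
    : std::true_type {};

// Used when the layer defines the method.
template <typename layer_type>
typename std::enable_if<has_lr_multiplier<layer_type>::value, double>::type
get_learning_rate_multiplier(const layer_type& l) { return l.get_learning_rate_multiplier(); }

// Fallback: layers without the method behave as if the multiplier were 1.
template <typename layer_type>
typename std::enable_if<!has_lr_multiplier<layer_type>::value, double>::type
get_learning_rate_multiplier(const layer_type&) { return 1; }
```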
@@ -78,6 +78,15 @@ namespace dlib
                 V = momentum*V - weight_decay*learning_rate*l.get_layer_params() - learning_rate*params_grad;
             Here V is a momentum term that is remembered by the solver from one
             invocation of operator() to the next.
+            Note that the actual learning rate and weight decay used by the solver are
+            multiplied by the per layer multipliers. That is, the solver will call
+            get_learning_rate_multiplier(l) and get_weight_decay_multiplier(l) and
+            multiply these values with the nominal learning rate and weight decay,
+            respectively, to determine the values it will use during each step. It is
+            also overloaded to allow additional learning rate multipliers to be applied
+            to fc_ and con_ bias parameters.
         !*/
     public:
@@ -123,6 +132,15 @@ namespace dlib
             paper:
                 Kingma, Diederik P., and Jimmy Ba Adam. "A method for stochastic
                 optimization." International Conference on Learning Representation. 2015.
+            Note that the actual learning rate and weight decay used by the solver are
+            multiplied by the per layer multipliers. That is, the solver will call
+            get_learning_rate_multiplier(l) and get_weight_decay_multiplier(l) and
+            multiply these values with the nominal learning rate and weight decay,
+            respectively, to determine the values it will use during each step. It is
+            also overloaded to allow additional learning rate multipliers to be applied
+            to fc_ and con_ bias parameters.
         !*/
     public:
...
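A worked example of how the nominal values compose with the multipliers described above; the numbers are made up for illustration:

```cpp
#include <iostream>

int main()
{
    const double learning_rate = 0.01, weight_decay = 0.0005;  // nominal solver values
    const double lr_mult = 0.1, wd_mult = 1;                   // layer multipliers
    const double bias_lr_mult = 2, bias_wd_mult = 0;           // extra bias multipliers

    // weights: lr = 0.01*0.1 = 0.001,   wd = 0.0005*1 = 0.0005
    std::cout << "weights: lr=" << learning_rate*lr_mult
              << " wd=" << weight_decay*wd_mult << "\n";
    // biases:  lr = 0.01*0.1*2 = 0.002, wd = 0.0005*1*0 = 0
    std::cout << "biases : lr=" << learning_rate*lr_mult*bias_lr_mult
              << " wd=" << weight_decay*wd_mult*bias_wd_mult << "\n";
}
```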
@@ -311,6 +311,8 @@ namespace dlib { namespace tt
 // ----------------------------------------------------------------------------------------
     void compute_adam_update (
+        size_t begin,
+        size_t end,
         tensor& s,
         tensor& m,
         tensor& v,
@@ -324,10 +326,10 @@ namespace dlib { namespace tt
     )
     {
 #ifdef DLIB_USE_CUDA
-        cuda::compute_adam_update(s, m, v, t, learning_rate, weight_decay, momentum1,
+        cuda::compute_adam_update(begin, end, s, m, v, t, learning_rate, weight_decay, momentum1,
             momentum2, params, params_grad);
 #else
-        cpu::compute_adam_update(s, m, v, t, learning_rate, weight_decay, momentum1,
+        cpu::compute_adam_update(begin, end, s, m, v, t, learning_rate, weight_decay, momentum1,
             momentum2, params, params_grad);
 #endif
     }
...
@@ -335,6 +335,8 @@ namespace dlib { namespace tt
 // ----------------------------------------------------------------------------------------
     void compute_adam_update (
+        size_t begin,
+        size_t end,
         tensor& s,
         tensor& m,
         tensor& v,
@@ -354,12 +356,16 @@ namespace dlib { namespace tt
             - weight_decay >= 0
             - 0 <= momentum1 < 1
             - 0 <= momentum2 < 1
+            - begin <= end <= params.size()
         ensures
             - This function implements the ADAM parameter update method described in the paper:
                 Kingma, Diederik P., and Jimmy Ba Adam. "A method for stochastic
                 optimization." International Conference on Learning Representation. 2015.
              Specifically, it implements the method shown as Algorithm 1.
            - #s is the update vector that should be added to the parameters.
+            - The function only operates in the half open range [begin,end) of the memory
+              blocks of each tensor. E.g. to make this function run on the entire tensor
+              set begin to 0 and end to params.size().
    !*/
// ----------------------------------------------------------------------------------------
...
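A hypothetical sketch of how a caller can use the [begin,end) range to apply different hyperparameters to the weight and bias portions of one parameter tensor, mirroring what the adam solver above does. The wrapper, its name, and the bias_offset/multiplier arguments are assumptions of the example, not part of the commit:

```cpp
#include <dlib/dnn.h>  // brings in dlib::tensor and dlib::tt::compute_adam_update

// lr/wd already include the layer-wide multipliers; the bias range simply gets
// the extra bias multipliers folded in on top.
void adam_step_with_bias_multipliers(
    dlib::tensor& s, dlib::tensor& m, dlib::tensor& v, float t,
    float lr, float wd, float bias_lr_mult, float bias_wd_mult,
    float momentum1, float momentum2,
    const dlib::tensor& params, const dlib::tensor& params_grad,
    size_t bias_offset)
{
    namespace tt = dlib::tt;
    // weights live in [0, bias_offset)
    tt::compute_adam_update(0, bias_offset, s, m, v, t, lr, wd,
                            momentum1, momentum2, params, params_grad);
    // biases live in [bias_offset, params.size())
    tt::compute_adam_update(bias_offset, params.size(), s, m, v, t,
                            lr*bias_lr_mult, wd*bias_wd_mult,
                            momentum1, momentum2, params, params_grad);
}
```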