Commit 15b2d7b5 authored by Davis King

Added get_learning_rate_multiplier() and get_weight_decay_multiplier() global functions.
parent 58496f9f
@@ -488,6 +488,13 @@ namespace dlib
// ----------------------------------------------------------------------------------------
    struct general_ {};
    struct special_ : general_ {};
    template<typename> struct int_ { typedef int type; };
// ----------------------------------------------------------------------------------------
    /*!A is_same_object

        This is a templated function which checks if both of its arguments are actually
......
@@ -24,6 +24,38 @@
namespace dlib
{
// ----------------------------------------------------------------------------------------
    namespace impl
    {
        // Selected when T has a get_learning_rate_multiplier() member function;
        // for any other T the decltype expression is ill-formed and SFINAE
        // discards this overload.
        template <typename T, typename int_<decltype(&T::get_learning_rate_multiplier)>::type = 0>
        double get_learning_rate_multiplier (
            const T& obj,
            special_
        ) { return obj.get_learning_rate_multiplier(); }

        // Fallback overload, reached via the special_ -> general_ conversion
        // when the overload above doesn't exist.
        template <typename T>
        double get_learning_rate_multiplier ( const T& obj, general_) { return 1; }
    }

    template <typename T>
    double get_learning_rate_multiplier(const T& obj) { return impl::get_learning_rate_multiplier(obj, special_()); }
// ----------------------------------------------------------------------------------------
    namespace impl
    {
        // Same dispatch pattern as above, but for the weight decay multiplier.
        template <typename T, typename int_<decltype(&T::get_weight_decay_multiplier)>::type = 0>
        double get_weight_decay_multiplier (
            const T& obj,
            special_
        ) { return obj.get_weight_decay_multiplier(); }

        template <typename T>
        double get_weight_decay_multiplier ( const T& obj, general_) { return 1; }
    }

    template <typename T>
    double get_weight_decay_multiplier(const T& obj) { return impl::get_weight_decay_multiplier(obj, special_()); }
// ----------------------------------------------------------------------------------------
    namespace impl
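For reference, the snippet below is a self-contained sketch of how this dispatch resolves in practice. The general_/special_/int_ helpers are copied from the hunk above so it compiles on its own, and the two layer types are hypothetical stand-ins for illustration, not real dlib layers:

```cpp
#include <iostream>

// Local copies of the dlib helpers so this sketch is self-contained.
struct general_ {};
struct special_ : general_ {};
template <typename> struct int_ { typedef int type; };

namespace impl
{
    // Viable only when T has a get_learning_rate_multiplier() member;
    // special_ is then an exact match, so this overload wins.
    template <typename T, typename int_<decltype(&T::get_learning_rate_multiplier)>::type = 0>
    double get_learning_rate_multiplier (const T& obj, special_)
    { return obj.get_learning_rate_multiplier(); }

    // Fallback: special_ converts to its base general_, so this is picked
    // only when the overload above is removed by SFINAE.
    template <typename T>
    double get_learning_rate_multiplier (const T&, general_)
    { return 1; }
}

template <typename T>
double get_learning_rate_multiplier (const T& obj)
{ return impl::get_learning_rate_multiplier(obj, special_()); }

// Hypothetical layer types used only for this demonstration.
struct layer_with_multiplier
{
    double get_learning_rate_multiplier() const { return 0.1; }
};
struct plain_layer {};

int main()
{
    std::cout << get_learning_rate_multiplier(layer_with_multiplier()) << "\n"; // prints 0.1
    std::cout << get_learning_rate_multiplier(plain_layer()) << "\n";           // prints 1
}
```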
@@ -849,8 +881,9 @@ namespace dlib
        void update_parameters(sstack<solver_type> solvers, double learning_rate)
        {
            DLIB_CASSERT(solvers.size()>=num_computational_layers,"");
-           // Don't try to adjust the parameters if this layer doesn't have any.
-           if (params_grad.size() != 0)
+           // Don't try to adjust the parameters if this layer doesn't have any or the
+           // learning rate is disabled for this layer.
+           if (params_grad.size() != 0 && get_learning_rate_multiplier(details) != 0)
            {
                const tensor& step = solvers.top()(learning_rate, details, static_cast<const tensor&>(params_grad));
                tt::add(details.get_layer_params(), details.get_layer_params(), step);
@@ -1200,8 +1233,9 @@ namespace dlib
        void update_parameters(sstack<solver_type> solvers, double learning_rate)
        {
            DLIB_CASSERT(solvers.size()>=num_computational_layers,"");
-           // Don't try to adjust the parameters if this layer doesn't have any.
-           if (params_grad.size() != 0)
+           // Don't try to adjust the parameters if this layer doesn't have any or the
+           // learning rate is disabled for this layer.
+           if (params_grad.size() != 0 && get_learning_rate_multiplier(details) != 0)
            {
                const tensor& step = solvers.top()(learning_rate, details, static_cast<const tensor&>(params_grad));
                tt::add(details.get_layer_params(), details.get_layer_params(), step);
......
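To see what the new guard does in isolation, here is a minimal stand-alone sketch. update_parameters_sketch and frozen_layer are hypothetical stand-ins for illustration; the real code calls the global get_learning_rate_multiplier() shown earlier (which defaults to 1 for layers without the member) rather than calling the member directly:

```cpp
#include <cstddef>
#include <iostream>

// Hypothetical layer that disables learning by returning a 0 multiplier.
struct frozen_layer
{
    double get_learning_rate_multiplier() const { return 0; }
};

template <typename LAYER>
void update_parameters_sketch (const LAYER& details, std::size_t params_grad_size)
{
    // Mirrors the condition added above: skip the solver step entirely when
    // the layer has no parameters or its learning rate multiplier is 0.
    if (params_grad_size != 0 && details.get_learning_rate_multiplier() != 0)
        std::cout << "run the solver step\n";
    else
        std::cout << "skip: no parameters or learning disabled\n";
}

int main()
{
    update_parameters_sketch(frozen_layer(), 128); // prints the skip message
}
```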
@@ -67,6 +67,32 @@ namespace dlib
        (except computes it using a numerically accurate method)
    !*/
// ----------------------------------------------------------------------------------------
    template <typename T>
    double get_learning_rate_multiplier(
        const T& obj
    );
    /*!
        ensures
            - if (obj has a get_learning_rate_multiplier() member function) then
                - returns obj.get_learning_rate_multiplier()
            - else
                - returns 1
    !*/

    template <typename T>
    double get_weight_decay_multiplier(
        const T& obj
    );
    /*!
        ensures
            - if (obj has a get_weight_decay_multiplier() member function) then
                - returns obj.get_weight_decay_multiplier()
            - else
                - returns 1
    !*/
// ----------------------------------------------------------------------------------------
    bool dnn_prefer_fastest_algorithms(
......
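As a rough illustration of why these multipliers exist, a solver can scale its global hyperparameters per layer. The snippet below is a hypothetical sketch, not dlib's actual solver code; example_layer and the base rates are made up for illustration:

```cpp
#include <iostream>

// Hypothetical layer: learns slowly and opts out of weight decay
// (a common choice for bias parameters, for example).
struct example_layer
{
    double get_learning_rate_multiplier() const { return 0.1; }
    double get_weight_decay_multiplier()  const { return 0; }
};

int main()
{
    const double base_learning_rate = 0.01;
    const double base_weight_decay  = 0.0005;
    example_layer l;

    // Per-layer effective hyperparameters: the global functions documented
    // above would return these member values, or 1 for layers without them.
    const double lr = base_learning_rate * l.get_learning_rate_multiplier();
    const double wd = base_weight_decay  * l.get_weight_decay_multiplier();

    std::cout << "effective learning rate: " << lr << "\n"; // 0.001
    std::cout << "effective weight decay:  " << wd << "\n"; // 0
}
```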