Commit 3b35f55b authored by Davis King's avatar Davis King

Added visit_layer_parameters()

parent 1fc117e1
...@@ -83,6 +83,8 @@ namespace dlib ...@@ -83,6 +83,8 @@ namespace dlib
template <typename T> struct is_nonloss_layer_type : std::false_type {}; template <typename T> struct is_nonloss_layer_type : std::false_type {};
// Tell us if T is an instance of add_loss_layer. // Tell us if T is an instance of add_loss_layer.
template <typename T> struct is_loss_layer_type : std::false_type {}; template <typename T> struct is_loss_layer_type : std::false_type {};
// Tell us if T is an instance of add_layer.  This is the primary template:
// it defaults to false_type, and specializations elsewhere in this file mark
// the actual add_layer instantiations as true_type.
template <typename T> struct is_add_layer : std::false_type {};
namespace impl namespace impl
{ {
...@@ -540,6 +542,7 @@ namespace dlib ...@@ -540,6 +542,7 @@ namespace dlib
template <typename LAYER_DETAILS, typename SUBNET, typename enabled = void> template <typename LAYER_DETAILS, typename SUBNET, typename enabled = void>
class add_layer; class add_layer;
template <typename T, typename U> template <typename T, typename U>
struct is_nonloss_layer_type<add_layer<T,U>> : std::true_type {}; struct is_nonloss_layer_type<add_layer<T,U>> : std::true_type {};
...@@ -947,6 +950,15 @@ namespace dlib ...@@ -947,6 +950,15 @@ namespace dlib
}; };
// is_add_layer must answer true no matter what const or reference qualifiers
// are attached to the queried type.  The forwarding-reference overloads that
// use this trait (see invoke_functor() below) deduce their argument as a
// possibly-const lvalue reference, so we specialize for each such variant.
template <typename T, typename U, typename E>
struct is_add_layer<add_layer<T,U,E>> : std::true_type {};
template <typename T, typename U, typename E>
struct is_add_layer<const add_layer<T,U,E>> : std::true_type {};
template <typename T, typename U, typename E>
struct is_add_layer<add_layer<T,U,E>&> : std::true_type {};
template <typename T, typename U, typename E>
struct is_add_layer<const add_layer<T,U,E>&> : std::true_type {};
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// This version of add_layer handles the special case where the subnetwork being given is // This version of add_layer handles the special case where the subnetwork being given is
...@@ -2970,6 +2982,73 @@ namespace dlib ...@@ -2970,6 +2982,73 @@ namespace dlib
return impl_test_layer(l, 0.01); return impl_test_layer(l, 0.01);
} }
// ----------------------------------------------------------------------------------------
namespace impl
{
    // Compile-time loop over the layers of a network, used by
    // visit_layer_parameters().  It walks layer indices i, i+1, ..., num-1
    // and hands the parameter tensor of every computational layer (i.e.
    // every add_layer) to a user supplied visitor.
    template <size_t i, size_t num>
    struct vlp_loop
    {
        // Selected when U is NOT an add_layer (e.g. a tag, skip, loss, or
        // input layer).  Such layers carry no parameters, so there is
        // nothing to visit and comp_i is not advanced.
        template <typename T, typename U>
        static typename std::enable_if<!is_add_layer<U>::value>::type invoke_functor(T&& , size_t& , U&& )
        {
            // intentionally left empty
        }

        // Selected when U IS an add_layer: pass the layer's parameter
        // tensor to the visitor v along with the running computational
        // layer index, then advance that index.  Note that comp_i is taken
        // by reference so the increment is seen by subsequent calls.
        template <typename T, typename U>
        static typename std::enable_if<is_add_layer<U>::value>::type invoke_functor(T&& v , size_t& comp_i, U&& l )
        {
            v(comp_i, l.layer_details().get_layer_params());
            ++comp_i;
        }

        // Visit layer i of net, then recurse to process layer i+1.
        // comp_i counts only the computational layers seen so far, so it
        // can lag behind i when tag/skip/loss layers are interleaved.
        template <
            typename net_type,
            typename visitor
            >
        static void visit(
            size_t comp_i,
            net_type& net,
            visitor&& v
        )
        {
            invoke_functor(v, comp_i, layer<i>(net));
            vlp_loop<i+1, num>::visit(comp_i, net,v);
        }
    };

    // Specialization that terminates the recursion once the layer index
    // reaches num (i.e. all num layers have been visited).
    template <size_t num>
    struct vlp_loop<num,num>
    {
        template <
            typename net_type,
            typename visitor
            >
        static void visit(
            size_t,
            net_type&,
            visitor&&
        )
        {
            // Base case of recursion.  Don't do anything.
        }
    };
}
// Apply the given visitor to the parameter tensor of every computational
// layer in net.  The visitor is called as v(idx, params) where idx numbers
// the computational layers starting from 0, in network order.
template <
    typename net_type,
    typename visitor
    >
void visit_layer_parameters(
    net_type& net,
    visitor v
)
{
    // Running count of computational layers handed to the visitor so far.
    // The compile-time loop below walks all net_type::num_layers layers and
    // only invokes v on those that are add_layer instances.
    size_t computational_layer_idx = 0;
    impl::vlp_loop<0, net_type::num_layers>::visit(computational_layer_idx, net, v);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
...@@ -1315,6 +1315,39 @@ namespace dlib ...@@ -1315,6 +1315,39 @@ namespace dlib
- returns layer<i>(layer<Match>(n)) - returns layer<i>(layer<Match>(n))
!*/ !*/
// ----------------------------------------------------------------------------------------
template <
    typename net_type,
    typename visitor
    >
void visit_layer_parameters(
    net_type& net,
    visitor v
);
/*!
    requires
        - net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
          add_tag_layer.
        - v is a function object with a signature equivalent to:
            v(size_t idx, tensor& t)
    ensures
        - Loops over all the computational layers (i.e. layers with parameters, as
          opposed to loss, tag, or input layers) in net and passes their parameters to
          v().  To be specific, this function essentially performs the following:

            size_t computational_layer_idx = 0;
            for (size_t i = 0; i < net_type::num_layers; ++i)
            {
                if (layer<i>(net) is a computational layer)
                {
                    v(computational_layer_idx, layer<i>(net).layer_details().get_layer_params());
                    ++computational_layer_idx;
                }
            }
        - Since the parameters are passed to v() by non-const reference, v() is
          allowed to modify them in place.
        - When v() is called, the first argument is always < net_type::num_computational_layers.
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
struct layer_test_results struct layer_test_results
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment