Commit 390c8e90 authored by Davis King's avatar Davis King

Made layer_details() part of the SUBNET interface so that user defined layer

details objects can access each other.  Also added the input_layer() global
function for accessing the input layer specifically.
parent 285bba76
......@@ -503,9 +503,13 @@ namespace dlib
subnet_wrapper(const subnet_wrapper&) = delete;
subnet_wrapper& operator=(const subnet_wrapper&) = delete;
subnet_wrapper(T& /*l_*/) {}
// Nothing here because in this case T is one of the input layer types
subnet_wrapper(T& l_) : l(l_) {}
// Not much here because in this case T is one of the input layer types
// that doesn't have anything in it.
typedef T layer_details_type;
const layer_details_type& layer_details() const { return l; }
private:
T& l;
};
template <typename T>
......@@ -518,12 +522,16 @@ namespace dlib
typedef T wrapped_type;
const static size_t num_computational_layers = T::num_computational_layers;
const static size_t num_layers = T::num_layers;
typedef typename T::layer_details_type layer_details_type;
subnet_wrapper(T& l_) : l(l_),subnetwork(l.subnet()) {}
const tensor& get_output() const { return l.private_get_output(); }
tensor& get_gradient_input() { return l.private_get_gradient_input(); }
const layer_details_type& layer_details() const { return l.layer_details(); }
const subnet_wrapper<typename T::subnet_type,false>& subnet() const { return subnetwork; }
subnet_wrapper<typename T::subnet_type,false>& subnet() { return subnetwork; }
......@@ -542,12 +550,16 @@ namespace dlib
typedef T wrapped_type;
const static size_t num_computational_layers = T::num_computational_layers;
const static size_t num_layers = T::num_layers;
typedef typename T::layer_details_type layer_details_type;
subnet_wrapper(T& l_) : l(l_),subnetwork(l.subnet()) {}
const tensor& get_output() const { return l.get_output(); }
tensor& get_gradient_input() { return l.get_gradient_input(); }
const layer_details_type& layer_details() const { return l.layer_details(); }
const subnet_wrapper<typename T::subnet_type,false>& subnet() const { return subnetwork; }
subnet_wrapper<typename T::subnet_type,false>& subnet() { return subnetwork; }
......@@ -1358,6 +1370,7 @@ namespace dlib
public:
typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type;
typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
const static size_t num_layers = subnet_type::num_layers + 1;
const static size_t num_computational_layers = subnet_type::num_computational_layers;
const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor;
......@@ -1554,6 +1567,7 @@ namespace dlib
public:
typedef SUBNET subnet_type;
typedef typename SUBNET::input_type input_type;
typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
const static size_t comp_layers_in_each_group = (REPEATED_LAYER<SUBNET>::num_computational_layers-SUBNET::num_computational_layers);
const static size_t comp_layers_in_repeated_group = comp_layers_in_each_group*num;
const static size_t num_computational_layers = comp_layers_in_repeated_group + SUBNET::num_computational_layers;
......@@ -1825,6 +1839,7 @@ namespace dlib
public:
typedef INPUT_LAYER subnet_type;
typedef typename subnet_type::input_type input_type;
typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
const static size_t num_computational_layers = 0;
const static size_t num_layers = 2;
const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor;
......@@ -2544,6 +2559,16 @@ namespace dlib
return impl::layer_helper_match<Match,T,i>::layer(n);
}
// ----------------------------------------------------------------------------------------
template <typename net_type>
auto input_layer (
net_type& net
) -> decltype(layer<net_type::num_layers-1>(net))&
{
return layer<net_type::num_layers-1>(net);
}
// ----------------------------------------------------------------------------------------
template <template<typename> class TAG_TYPE, typename SUBNET>
......@@ -2552,6 +2577,7 @@ namespace dlib
public:
typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type;
typedef int layer_details_type; // not really used anywhere, but required by subnet_wrapper.
const static size_t num_layers = subnet_type::num_layers + 1;
const static size_t num_computational_layers = subnet_type::num_computational_layers;
const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor;
......
......@@ -1332,6 +1332,22 @@ namespace dlib
- returns layer<i>(layer<Match>(n))
!*/
// ----------------------------------------------------------------------------------------
template <typename net_type>
auto& input_layer (
net_type& net
);
/*!
requires
- net_type is an object of type add_layer, add_loss_layer, add_skip_layer, or
add_tag_layer.
ensures
- returns the input later of the given network object. Specifically, this
function is equivalent to calling:
layer<net_type::num_layers-1>(net);
!*/
// ----------------------------------------------------------------------------------------
template <
......
......@@ -82,6 +82,18 @@ namespace dlib
above, if *this was layer1 then subnet() would return the network that
begins with layer2.
!*/
const layer_details_type& layer_details(
) const;
/*!
ensures
- returns the layer_details_type instance that defines the behavior of the
layer at the top of this network. I.e. returns the layer details that
defines the behavior of the layer nearest to the network output rather
than the input layer. For computational layers, this is the object
implementing the EXAMPLE_COMPUTATIONAL_LAYER_ interface that defines the
layer's behavior.
!*/
};
// ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment