Commit b66c5254 authored by Davis King's avatar Davis King

Made the tuple based layer constructors work with nested tuples so you can

define combination layers made out of other combination layers without being
hassled by the compiler.
parent d2516bc2
...@@ -59,7 +59,7 @@ namespace dlib ...@@ -59,7 +59,7 @@ namespace dlib
template <size_t max> template <size_t max>
struct ct_make_integer_range struct ct_make_integer_range
{ {
// recursively call push_back on ct_integers_list to build a range from 0 to max // recursively call push_back on ct_integers_list to build a range from 1 to max
// inclusive. // inclusive.
typedef typename ct_make_integer_range<max-1>::type::template push_back<max>::type type; typedef typename ct_make_integer_range<max-1>::type::template push_back<max>::type type;
}; };
...@@ -79,6 +79,57 @@ namespace dlib ...@@ -79,6 +79,57 @@ namespace dlib
return std::make_tuple(std::get<indices>(item)...); return std::make_tuple(std::get<indices>(item)...);
} }
template <typename Head, typename... Tail>
std::tuple<Tail...> basic_tuple_tail(
const std::tuple<Head, Tail...>& item
)
{
return tuple_subset(item, typename ct_make_integer_range<sizeof...(Tail)>::type());
}
template <typename T>
std::tuple<T> tuple_flatten(const T& t)
{
return std::make_tuple(t);
}
template <typename... T>
auto tuple_flatten(
const std::tuple<T...>& item
) -> decltype(tuple_flatten(item, typename ct_make_integer_range<sizeof...(T)>::type()))
{
return tuple_flatten(item, typename ct_make_integer_range<sizeof...(T)>::type());
}
template <size_t... indices, typename... T>
auto tuple_flatten(
const std::tuple<T...>& item,
ct_integers_list<indices...>
) -> decltype(std::tuple_cat(tuple_flatten(std::get<indices-1>(item))...))
{
return std::tuple_cat(tuple_flatten(std::get<indices-1>(item))...);
}
template <typename T>
struct tuple_head_helper
{
typedef T type;
static const type& get(const T& item)
{
return item;
}
};
template <typename T, typename... U>
struct tuple_head_helper<std::tuple<T, U...>>
{
typedef typename tuple_head_helper<T>::type type;
static const type& get(const std::tuple<T,U...>& item)
{
return tuple_head_helper<T>::get(std::get<0>(item));
}
};
template <typename T> struct alwaysbool { typedef bool type; }; template <typename T> struct alwaysbool { typedef bool type; };
resizable_tensor& rt(); resizable_tensor& rt();
...@@ -275,14 +326,28 @@ namespace dlib ...@@ -275,14 +326,28 @@ namespace dlib
} // end namespace impl } // end namespace impl
template <typename Head, typename... Tail> template <typename... T>
std::tuple<Tail...> tuple_tail( typename impl::tuple_head_helper<std::tuple<T...>>::type tuple_head (
const std::tuple<Head, Tail...>& item const std::tuple<T...>& item
) )
{ {
return impl::tuple_subset(item, typename impl::ct_make_integer_range<sizeof...(Tail)>::type()); return impl::tuple_head_helper<std::tuple<T...>>::get(item);
} }
template <typename... T>
auto tuple_tail(
const std::tuple<T...>& item
) -> decltype(impl::basic_tuple_tail(impl::tuple_flatten(item)))
{
return impl::basic_tuple_tail(impl::tuple_flatten(item));
}
inline std::tuple<> tuple_tail(
const std::tuple<>& item
)
{
return item;
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
inline void randomize_parameters ( inline void randomize_parameters (
...@@ -540,12 +605,12 @@ namespace dlib ...@@ -540,12 +605,12 @@ namespace dlib
subnetwork.disable_output_and_gradient_getters(); subnetwork.disable_output_and_gradient_getters();
} }
template <typename ...T, typename ...U> template <typename ...T, typename LD, typename ...U>
add_layer( add_layer(
const std::tuple<LAYER_DETAILS,U...>& layer_det, const std::tuple<LD,U...>& layer_det,
T&& ...args T&& ...args
) : ) :
details(std::get<0>(layer_det)), details(tuple_head(layer_det)),
subnetwork(tuple_tail(layer_det),std::forward<T>(args)...), subnetwork(tuple_tail(layer_det),std::forward<T>(args)...),
this_layer_setup_called(false), this_layer_setup_called(false),
gradient_input_is_stale(true), gradient_input_is_stale(true),
...@@ -555,10 +620,10 @@ namespace dlib ...@@ -555,10 +620,10 @@ namespace dlib
subnetwork.disable_output_and_gradient_getters(); subnetwork.disable_output_and_gradient_getters();
} }
template <typename ...T, typename ...U> template <typename ...T, typename LD, typename ...U>
add_layer( add_layer(
std::tuple<>, std::tuple<>,
const std::tuple<LAYER_DETAILS,U...>& layer_det, const std::tuple<LD,U...>& layer_det,
T&& ...args T&& ...args
) : add_layer(layer_det,args...) { } ) : add_layer(layer_det,args...) { }
...@@ -859,12 +924,12 @@ namespace dlib ...@@ -859,12 +924,12 @@ namespace dlib
add_layer( add_layer(
const std::tuple<LAYER_DETAILS>& layer_det const std::tuple<LAYER_DETAILS>& layer_det
) : add_layer(std::get<0>(layer_det)) {} ) : add_layer(tuple_head(layer_det)) {}
add_layer( add_layer(
const std::tuple<LAYER_DETAILS>& layer_det, const std::tuple<LAYER_DETAILS>& layer_det,
INPUT_LAYER il INPUT_LAYER il
) : add_layer(std::get<0>(layer_det),il) {} ) : add_layer(tuple_head(layer_det),il) {}
template <typename input_iterator> template <typename input_iterator>
void to_tensor ( void to_tensor (
......
...@@ -33,16 +33,28 @@ namespace dlib ...@@ -33,16 +33,28 @@ namespace dlib
!*/ !*/
template < template <
typename Head, typename... T
typename... Tail
> >
std::tuple<Tail...> tuple_tail( auto tuple_tail(
const std::tuple<Head, Tail...>& item const std::tuple<T...>& item
); );
/*! /*!
ensures ensures
- returns a tuple that contains everything in item except for get<0>(item). - returns a tuple that contains everything in item except for tuple_head(item).
So it basically returns make_tuple(get<1>(item),get<2>(item),get<3>(item), and so on). The items will be in the same order as they are in item, just without
tuple_head(item).
- This function will correctly handle nested tuples.
!*/
template <typename... T>
auto tuple_head (
const std::tuple<T...>& item
);
/*!
ensures
- returns a copy of the first thing in the tuple that isn't a std::tuple.
Essentially, this function calls std::get<0>() recursively on item until
a non-std::tuple object is found.
!*/ !*/
double log1pexp( double log1pexp(
...@@ -214,14 +226,14 @@ namespace dlib ...@@ -214,14 +226,14 @@ namespace dlib
- #subnet() == subnet_type(item.subnet()) - #subnet() == subnet_type(item.subnet())
!*/ !*/
template <typename ...T, typename ...U> template <typename ...T, typename LD, typename ...U>
add_layer( add_layer(
const std::tuple<layer_details_type,U...>& layer_det, const std::tuple<LD,U...>& layer_det,
T&& ...args T&& ...args
); );
/*! /*!
ensures ensures
- #layer_details() == layer_details_type(get<0>(layer_det)) - #layer_details() == layer_details_type(tuple_head(layer_det))
- #subnet() == subnet_type(tuple_tail(layer_det),args) - #subnet() == subnet_type(tuple_tail(layer_det),args)
!*/ !*/
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment