Commit 01b3b08b authored by Fm's avatar Fm

Replaced sizeof... with variadic templates

parent 1974e68d
...@@ -1844,6 +1844,7 @@ namespace dlib ...@@ -1844,6 +1844,7 @@ namespace dlib
}; };
template <template<typename> class TAG_TYPE> template <template<typename> class TAG_TYPE>
struct concat_helper_impl<TAG_TYPE>{ struct concat_helper_impl<TAG_TYPE>{
constexpr static size_t tag_count() {return 1;}
template<typename SUBNET> template<typename SUBNET>
static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k) static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k)
{ {
...@@ -1865,6 +1866,9 @@ namespace dlib ...@@ -1865,6 +1866,9 @@ namespace dlib
}; };
template <template<typename> class TAG_TYPE, template<typename> class... TAG_TYPES> template <template<typename> class TAG_TYPE, template<typename> class... TAG_TYPES>
struct concat_helper_impl<TAG_TYPE, TAG_TYPES...>{ struct concat_helper_impl<TAG_TYPE, TAG_TYPES...>{
constexpr static size_t tag_count() {return 1 + concat_helper_impl<TAG_TYPES...>::tag_count();}
template<typename SUBNET> template<typename SUBNET>
static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k) static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k)
{ {
...@@ -1896,6 +1900,8 @@ namespace dlib ...@@ -1896,6 +1900,8 @@ namespace dlib
class concat_ class concat_
{ {
public: public:
constexpr static size_t tag_count() {return impl::concat_helper_impl<TAG_TYPES...>::tag_count();};
template <typename SUBNET> template <typename SUBNET>
void setup (const SUBNET&) void setup (const SUBNET&)
{ {
...@@ -1924,7 +1930,8 @@ namespace dlib ...@@ -1924,7 +1930,8 @@ namespace dlib
friend void serialize(const concat_& item, std::ostream& out) friend void serialize(const concat_& item, std::ostream& out)
{ {
serialize("concat_", out); serialize("concat_", out);
serialize(sizeof...(TAG_TYPES), out); size_t count = tag_count();
serialize(count, out);
} }
friend void deserialize(concat_& item, std::istream& in) friend void deserialize(concat_& item, std::istream& in)
...@@ -1935,15 +1942,16 @@ namespace dlib ...@@ -1935,15 +1942,16 @@ namespace dlib
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::concat_."); throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::concat_.");
size_t count_tags; size_t count_tags;
deserialize(count_tags, in); deserialize(count_tags, in);
if (count_tags != sizeof...(TAG_TYPES)) if (count_tags != tag_count())
throw serialization_error("Invalid count of tags "+ std::to_string(count_tags) +", expecting " + throw serialization_error("Invalid count of tags "+ std::to_string(count_tags) +", expecting " +
std::to_string(sizeof...(TAG_TYPES)) + " found while deserializing dlib::concat_."); std::to_string(tag_count()) +
" found while deserializing dlib::concat_.");
} }
friend std::ostream& operator<<(std::ostream& out, const concat_& item) friend std::ostream& operator<<(std::ostream& out, const concat_& item)
{ {
out << "concat\t (" out << "concat\t ("
<< sizeof...(TAG_TYPES) << tag_count()
<< ")"; << ")";
return out; return out;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment