Commit 29897305 authored by Davis King's avatar Davis King

Work around funny name lookup rules for serialize() call.

parent 32fb83b3
...@@ -2158,11 +2158,6 @@ namespace dlib ...@@ -2158,11 +2158,6 @@ namespace dlib
template < typename net_type, typename solver_type > friend class dnn_trainer; template < typename net_type, typename solver_type > friend class dnn_trainer;
}; };
template <typename LOSS_DETAILS, typename SUBNET>
void serialize(const add_loss_layer<LOSS_DETAILS,SUBNET>& item, std::ostream& out);
template <typename LOSS_DETAILS, typename SUBNET>
void deserialize(add_loss_layer<LOSS_DETAILS,SUBNET>& item, std::istream& in);
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename LOSS_DETAILS, typename SUBNET> template <typename LOSS_DETAILS, typename SUBNET>
...@@ -2457,23 +2452,10 @@ namespace dlib ...@@ -2457,23 +2452,10 @@ namespace dlib
subnetwork.clean(); subnetwork.clean();
} }
friend void serialize(const add_loss_layer& item, std::ostream& out) template <typename T, typename U>
{ friend void serialize(const add_loss_layer<T,U>& item, std::ostream& out);
int version = 1; template <typename T, typename U>
serialize(version, out); friend void deserialize(add_loss_layer<T,U>& item, std::istream& in);
serialize(item.loss, out);
serialize(item.subnetwork, out);
}
friend void deserialize(add_loss_layer& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::add_loss_layer.");
deserialize(item.loss, in);
deserialize(item.subnetwork, in);
}
friend std::ostream& operator<< (std::ostream& out, const add_loss_layer& item) friend std::ostream& operator<< (std::ostream& out, const add_loss_layer& item)
{ {
...@@ -2506,6 +2488,26 @@ namespace dlib ...@@ -2506,6 +2488,26 @@ namespace dlib
resizable_tensor temp_tensor; resizable_tensor temp_tensor;
}; };
template <typename LOSS_DETAILS, typename SUBNET>
void serialize(const add_loss_layer<LOSS_DETAILS,SUBNET>& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.loss, out);
serialize(item.subnetwork, out);
}
template <typename LOSS_DETAILS, typename SUBNET>
void deserialize(add_loss_layer<LOSS_DETAILS,SUBNET>& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::add_loss_layer.");
deserialize(item.loss, in);
deserialize(item.subnetwork, in);
}
template <typename T, typename U> template <typename T, typename U>
struct is_loss_layer_type<add_loss_layer<T,U>> : std::true_type {}; struct is_loss_layer_type<add_loss_layer<T,U>> : std::true_type {};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment