Commit 6c36592c authored by Davis King's avatar Davis King

Added serialization support to everything.

parent e679d66a
...@@ -67,6 +67,18 @@ namespace dlib ...@@ -67,6 +67,18 @@ namespace dlib
const sstack<T,N-1>& pop() const { return data; } const sstack<T,N-1>& pop() const { return data; }
sstack<T,N-1>& pop() { return data; } sstack<T,N-1>& pop() { return data; }
friend void serialize(const sstack& item, std::ostream& out)
{
serialize(item.top(), out);
serialize(item.pop(), out);
}
friend void deserialize(sstack& item, std::istream& in)
{
deserialize(item.top(), in);
deserialize(item.pop(), in);
}
private: private:
T item; T item;
sstack<T,N-1> data; sstack<T,N-1> data;
...@@ -83,6 +95,17 @@ namespace dlib ...@@ -83,6 +95,17 @@ namespace dlib
T& top() { return item; } T& top() { return item; }
size_t size() const { return 1; } size_t size() const { return 1; }
friend void serialize(const sstack& item, std::ostream& out)
{
serialize(item.top(), out);
}
friend void deserialize(sstack& item, std::istream& in)
{
deserialize(item.top(), in);
}
private: private:
T item; T item;
}; };
...@@ -294,6 +317,32 @@ namespace dlib ...@@ -294,6 +317,32 @@ namespace dlib
subnetwork.clean(); subnetwork.clean();
} }
friend void serialize(const add_layer& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.subnetwork, out);
serialize(item.details, out);
serialize(item.this_layer_setup_called, out);
serialize(item.gradient_input_is_stale, out);
serialize(item.x_grad, out);
serialize(item.cached_output, out);
}
friend void deserialize(add_layer& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::add_layer.");
deserialize(item.subnetwork, in);
deserialize(item.details, in);
deserialize(item.this_layer_setup_called, in);
deserialize(item.gradient_input_is_stale, in);
deserialize(item.x_grad, in);
deserialize(item.cached_output, in);
}
private: private:
...@@ -468,6 +517,32 @@ namespace dlib ...@@ -468,6 +517,32 @@ namespace dlib
gradient_input_is_stale = true; gradient_input_is_stale = true;
} }
friend void serialize(const add_layer& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.input_layer, out);
serialize(item.details, out);
serialize(item.this_layer_setup_called, out);
serialize(item.gradient_input_is_stale, out);
serialize(item.x_grad, out);
serialize(item.cached_output, out);
}
friend void deserialize(add_layer& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::add_layer.");
deserialize(item.input_layer, in);
deserialize(item.details, in);
deserialize(item.this_layer_setup_called, in);
deserialize(item.gradient_input_is_stale, in);
deserialize(item.x_grad, in);
deserialize(item.cached_output, in);
}
private: private:
class subnet_wrapper class subnet_wrapper
...@@ -601,6 +676,22 @@ namespace dlib ...@@ -601,6 +676,22 @@ namespace dlib
subnetwork.clean(); subnetwork.clean();
} }
friend void serialize(const add_tag_layer& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.subnetwork, out);
}
friend void deserialize(add_tag_layer& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::add_tag_layer.");
deserialize(item.subnetwork, in);
}
private: private:
subnet_type subnetwork; subnet_type subnetwork;
...@@ -702,6 +793,26 @@ namespace dlib ...@@ -702,6 +793,26 @@ namespace dlib
cached_output.clear(); cached_output.clear();
} }
friend void serialize(const add_tag_layer& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.input_layer, out);
serialize(item.cached_output, out);
serialize(item.grad_final_ignored, out);
}
friend void deserialize(add_tag_layer& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::add_tag_layer.");
deserialize(item.input_layer, in);
deserialize(item.cached_output, in);
deserialize(item.grad_final_ignored, in);
}
private: private:
subnet_type input_layer; subnet_type input_layer;
...@@ -759,7 +870,8 @@ namespace dlib ...@@ -759,7 +870,8 @@ namespace dlib
const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor; const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor;
typedef typename get_loss_layer_label_type<LOSS_DETAILS>::type label_type; typedef typename get_loss_layer_label_type<LOSS_DETAILS>::type label_type;
static_assert(is_nonloss_layer_type<SUBNET>::value, "SUBNET must be of type add_layer, add_skip_layer, or add_tag_layer."); static_assert(is_nonloss_layer_type<SUBNET>::value,
"SUBNET must be of type add_layer, add_skip_layer, or add_tag_layer.");
static_assert(sample_expansion_factor == LOSS_DETAILS::sample_expansion_factor, static_assert(sample_expansion_factor == LOSS_DETAILS::sample_expansion_factor,
"The loss layer and input layer must agree on the sample_expansion_factor."); "The loss layer and input layer must agree on the sample_expansion_factor.");
...@@ -947,6 +1059,24 @@ namespace dlib ...@@ -947,6 +1059,24 @@ namespace dlib
subnetwork.clear(); subnetwork.clear();
} }
friend void serialize(const add_loss_layer& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.loss, out);
serialize(item.subnetwork, out);
}
friend void deserialize(add_loss_layer& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::add_loss_layer.");
deserialize(item.loss, in);
deserialize(item.subnetwork, in);
}
private: private:
loss_details_type loss; loss_details_type loss;
...@@ -1150,6 +1280,22 @@ namespace dlib ...@@ -1150,6 +1280,22 @@ namespace dlib
subnetwork.clean(); subnetwork.clean();
} }
friend void serialize(const add_skip_layer& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.subnetwork, out);
}
friend void deserialize(add_skip_layer& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::add_skip_layer.");
deserialize(item.subnetwork, in);
}
private: private:
subnet_type subnetwork; subnet_type subnetwork;
......
...@@ -119,6 +119,12 @@ namespace dlib ...@@ -119,6 +119,12 @@ namespace dlib
!*/ !*/
}; };
void serialize(const sstack& item, std::ostream& out);
void deserialize(sstack& item, std::istream& in);
/*!
provides serialization support
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -378,6 +384,14 @@ namespace dlib ...@@ -378,6 +384,14 @@ namespace dlib
}; };
template <typename T, typename U>,
void serialize(const add_layer<T,U>& item, std::ostream& out);
template <typename T, typename U>,
void deserialize(add_layer<T,U>& item, std::istream& in);
/*!
provides serialization support
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -769,6 +783,14 @@ namespace dlib ...@@ -769,6 +783,14 @@ namespace dlib
!*/ !*/
}; };
template <typename T, typename U>,
void serialize(const add_loss_layer<T,U>& item, std::ostream& out);
template <typename T, typename U>,
void deserialize(add_loss_layer<T,U>& item, std::istream& in);
/*!
provides serialization support
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -799,6 +821,14 @@ namespace dlib ...@@ -799,6 +821,14 @@ namespace dlib
!*/ !*/
}; };
template <unsigned long ID, typename U>,
void serialize(const add_tag_layer<ID,U>& item, std::ostream& out);
template <unsigned long ID, typename U>,
void deserialize(add_tag_layer<ID,U>& item, std::istream& in);
/*!
provides serialization support
!*/
template <typename SUBNET> using tag1 = add_tag_layer< 1, SUBNET>; template <typename SUBNET> using tag1 = add_tag_layer< 1, SUBNET>;
template <typename SUBNET> using tag2 = add_tag_layer< 2, SUBNET>; template <typename SUBNET> using tag2 = add_tag_layer< 2, SUBNET>;
template <typename SUBNET> using tag3 = add_tag_layer< 3, SUBNET>; template <typename SUBNET> using tag3 = add_tag_layer< 3, SUBNET>;
...@@ -834,6 +864,14 @@ namespace dlib ...@@ -834,6 +864,14 @@ namespace dlib
!*/ !*/
}; };
template <template<typename> class T, typename U>
void serialize(const add_skip_layer<T,U>& item, std::ostream& out);
template <template<typename> class T, typename U>
void deserialize(add_skip_layer<T,U>& item, std::istream& in);
/*!
provides serialization support
!*/
template <typename SUBNET> using skip1 = add_skip_layer< tag1, SUBNET>; template <typename SUBNET> using skip1 = add_skip_layer< tag1, SUBNET>;
template <typename SUBNET> using skip2 = add_skip_layer< tag2, SUBNET>; template <typename SUBNET> using skip2 = add_skip_layer< tag2, SUBNET>;
template <typename SUBNET> using skip3 = add_skip_layer< tag3, SUBNET>; template <typename SUBNET> using skip3 = add_skip_layer< tag3, SUBNET>;
......
...@@ -73,6 +73,20 @@ namespace dlib ...@@ -73,6 +73,20 @@ namespace dlib
} }
} }
friend void serialize(const input& item, std::ostream& out)
{
serialize("input<matrix>", out);
}
friend void deserialize(input& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "input<matrix>")
throw serialization_error("Unexpected version found while deserializing dlib::input.");
}
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -126,6 +140,20 @@ namespace dlib ...@@ -126,6 +140,20 @@ namespace dlib
} }
} }
friend void serialize(const input& item, std::ostream& out)
{
serialize("input<array2d>", out);
}
friend void deserialize(input& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "input<array2d>")
throw serialization_error("Unexpected version found while deserializing dlib::input.");
}
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -86,6 +86,12 @@ namespace dlib ...@@ -86,6 +86,12 @@ namespace dlib
!*/ !*/
}; };
void serialize(const EXAMPLE_INPUT_LAYER& item, std::ostream& out);
void deserialize(EXAMPLE_INPUT_LAYER& item, std::istream& in);
/*!
provides serialization support
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -132,6 +138,14 @@ namespace dlib ...@@ -132,6 +138,14 @@ namespace dlib
!*/ !*/
}; };
template <typename T>
void serialize(const input<T>& item, std::ostream& out);
template <typename T>
void deserialize(input<T>& item, std::istream& in);
/*!
provides serialization support
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
...@@ -59,13 +59,12 @@ namespace dlib ...@@ -59,13 +59,12 @@ namespace dlib
public: public:
fc_() : num_outputs(1) fc_() : num_outputs(1)
{ {
rnd.set_seed("fc_" + cast_to_string(num_outputs));
} }
explicit fc_(unsigned long num_outputs_) explicit fc_(
unsigned long num_outputs_
) : num_outputs(num_outputs_)
{ {
num_outputs = num_outputs_;
rnd.set_seed("fc_" + cast_to_string(num_outputs));
} }
unsigned long get_num_outputs ( unsigned long get_num_outputs (
...@@ -77,6 +76,7 @@ namespace dlib ...@@ -77,6 +76,7 @@ namespace dlib
num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k(); num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k();
params.set_size(num_inputs, num_outputs); params.set_size(num_inputs, num_outputs);
dlib::rand rnd("fc_"+cast_to_string(num_outputs));
randomize_parameters(params, num_inputs+num_outputs, rnd); randomize_parameters(params, num_inputs+num_outputs, rnd);
} }
...@@ -101,12 +101,30 @@ namespace dlib ...@@ -101,12 +101,30 @@ namespace dlib
const tensor& get_layer_params() const { return params; } const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; } tensor& get_layer_params() { return params; }
friend void serialize(const fc_& item, std::ostream& out)
{
serialize("fc_", out);
serialize(item.num_outputs, out);
serialize(item.num_inputs, out);
serialize(item.params, out);
}
friend void deserialize(fc_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "fc_")
throw serialization_error("Unexpected version found while deserializing dlib::fc_.");
deserialize(item.num_outputs, in);
deserialize(item.num_inputs, in);
deserialize(item.params, in);
}
private: private:
unsigned long num_outputs; unsigned long num_outputs;
unsigned long num_inputs; unsigned long num_inputs;
resizable_tensor params; resizable_tensor params;
dlib::rand rnd;
}; };
...@@ -151,81 +169,28 @@ namespace dlib ...@@ -151,81 +169,28 @@ namespace dlib
const tensor& get_layer_params() const { return params; } const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; } tensor& get_layer_params() { return params; }
private: friend void serialize(const relu_& item, std::ostream& out)
resizable_tensor params;
};
template <typename SUBNET>
using relu = add_layer<relu_, SUBNET>;
// ----------------------------------------------------------------------------------------
class multiply_
{
public:
multiply_()
{
}
template <typename SUBNET>
void setup (const SUBNET& sub)
{ {
num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k(); serialize("relu_", out);
params.set_size(1, num_inputs);
std::cout << "multiply_::setup() " << params.size() << std::endl;
const int num_outputs = num_inputs;
randomize_parameters(params, num_inputs+num_outputs, rnd);
} }
template <typename SUBNET> friend void deserialize(relu_& item, std::istream& in)
void forward(const SUBNET& sub, resizable_tensor& output)
{
DLIB_CASSERT( sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k() == params.size(), "");
DLIB_CASSERT( sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k() == num_inputs, "");
output.copy_size(sub.get_output());
auto indata = sub.get_output().host();
auto outdata = output.host();
auto paramdata = params.host();
for (int i = 0; i < sub.get_output().num_samples(); ++i)
{
for (int j = 0; j < num_inputs; ++j)
{
*outdata++ = *indata++ * paramdata[j];
}
}
}
template <typename SUBNET>
void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
{ {
params_grad += sum_rows(pointwise_multiply(mat(sub.get_output()),mat(gradient_input))); std::string version;
deserialize(version, in);
for (long i = 0; i < gradient_input.num_samples(); ++i) if (version != "relu_")
{ throw serialization_error("Unexpected version found while deserializing dlib::relu_.");
sub.get_gradient_input().add_to_sample(i,
pointwise_multiply(rowm(mat(gradient_input),i), mat(params)));
}
} }
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
private: private:
int num_inputs;
resizable_tensor params; resizable_tensor params;
dlib::rand rnd;
}; };
template <typename SUBNET> template <typename SUBNET>
using multiply = add_layer<multiply_, SUBNET>; using relu = add_layer<relu_, SUBNET>;
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -218,6 +218,12 @@ namespace dlib ...@@ -218,6 +218,12 @@ namespace dlib
}; };
void serialize(const EXAMPLE_LAYER_& item, std::ostream& out);
void deserialize(EXAMPLE_LAYER_& item, std::istream& in);
/*!
provides serialization support
!*/
// For each layer you define, always define an add_layer template so that layers can be // For each layer you define, always define an add_layer template so that layers can be
// easily composed. Moreover, the convention is that the layer class ends with an _ // easily composed. Moreover, the convention is that the layer class ends with an _
// while the add_layer template has the same name but without the trailing _. // while the add_layer template has the same name but without the trailing _.
...@@ -274,6 +280,11 @@ namespace dlib ...@@ -274,6 +280,11 @@ namespace dlib
!*/ !*/
}; };
void serialize(const fc_& item, std::ostream& out);
void deserialize(fc_& item, std::istream& in);
/*!
provides serialization support
!*/
template <typename SUBNET> template <typename SUBNET>
using fc = add_layer<fc_, SUBNET>; using fc = add_layer<fc_, SUBNET>;
...@@ -306,6 +317,11 @@ namespace dlib ...@@ -306,6 +317,11 @@ namespace dlib
!*/ !*/
}; };
void serialize(const relu_& item, std::ostream& out);
void deserialize(relu_& item, std::istream& in);
/*!
provides serialization support
!*/
template <typename SUBNET> template <typename SUBNET>
using relu = add_layer<relu_, SUBNET>; using relu = add_layer<relu_, SUBNET>;
......
...@@ -81,6 +81,19 @@ namespace dlib ...@@ -81,6 +81,19 @@ namespace dlib
return loss; return loss;
} }
friend void serialize(const loss_binary_hinge_& item, std::ostream& out)
{
serialize("loss_binary_hinge_", out);
}
friend void deserialize(loss_binary_hinge_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "loss_binary_hinge_")
throw serialization_error("Unexpected version found while deserializing dlib::loss_binary_hinge_.");
}
}; };
template <typename SUBNET> template <typename SUBNET>
...@@ -105,6 +118,19 @@ namespace dlib ...@@ -105,6 +118,19 @@ namespace dlib
return 0; return 0;
} }
friend void serialize(const loss_no_label_& item, std::ostream& out)
{
serialize("loss_no_label_", out);
}
friend void deserialize(loss_no_label_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "loss_no_label_")
throw serialization_error("Unexpected version found while deserializing dlib::loss_no_label_.");
}
}; };
template <typename SUBNET> template <typename SUBNET>
......
...@@ -118,6 +118,12 @@ namespace dlib ...@@ -118,6 +118,12 @@ namespace dlib
!*/ !*/
}; };
void serialize(const EXAMPLE_LOSS_LAYER_& item, std::ostream& out);
void deserialize(EXAMPLE_LOSS_LAYER_& item, std::istream& in);
/*!
provides serialization support
!*/
// For each loss layer you define, always define an add_loss_layer template so that // For each loss layer you define, always define an add_loss_layer template so that
// layers can be easily composed. Moreover, the convention is that the layer class // layers can be easily composed. Moreover, the convention is that the layer class
// ends with an _ while the add_loss_layer template has the same name but without the // ends with an _ while the add_loss_layer template has the same name but without the
...@@ -179,6 +185,12 @@ namespace dlib ...@@ -179,6 +185,12 @@ namespace dlib
}; };
void serialize(const loss_binary_hinge_& item, std::ostream& out);
void deserialize(loss_binary_hinge_& item, std::istream& in);
/*!
provides serialization support
!*/
template <typename SUBNET> template <typename SUBNET>
using loss_binary_hinge = add_loss_layer<loss_binary_hinge_, SUBNET>; using loss_binary_hinge = add_loss_layer<loss_binary_hinge_, SUBNET>;
......
...@@ -48,6 +48,27 @@ namespace dlib ...@@ -48,6 +48,27 @@ namespace dlib
l.get_layer_params() += v; l.get_layer_params() += v;
} }
friend void serialize(const sgd& item, std::ostream& out)
{
serialize("sgd", out);
serialize(item.v, out);
serialize(item.weight_decay, out);
serialize(item.learning_rate, out);
serialize(item.momentum, out);
}
friend void deserialize(sgd& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "sgd")
throw serialization_error("Unexpected version found while deserializing dlib::sgd.");
deserialize(item.v, in);
deserialize(item.weight_decay, in);
deserialize(item.learning_rate, in);
deserialize(item.momentum, in);
}
private: private:
matrix<float> v; matrix<float> v;
float weight_decay; float weight_decay;
......
...@@ -52,6 +52,12 @@ namespace dlib ...@@ -52,6 +52,12 @@ namespace dlib
!*/ !*/
}; };
void serialize(const EXAMPLE_SOLVER& item, std::ostream& out);
void deserialize(EXAMPLE_SOLVER& item, std::istream& in);
/*!
provides serialization support
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -92,6 +98,12 @@ namespace dlib ...@@ -92,6 +98,12 @@ namespace dlib
float get_momentum () const; float get_momentum () const;
}; };
void serialize(const sgd& item, std::ostream& out);
void deserialize(sgd& item, std::istream& in);
/*!
provides serialization support
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
...@@ -112,6 +112,7 @@ namespace dlib ...@@ -112,6 +112,7 @@ namespace dlib
size_t size() const { return data_size; } size_t size() const { return data_size; }
private: private:
void copy_to_device() const void copy_to_device() const
...@@ -144,6 +145,30 @@ namespace dlib ...@@ -144,6 +145,30 @@ namespace dlib
std::unique_ptr<float[]> data_device; std::unique_ptr<float[]> data_device;
}; };
inline void serialize(const gpu_data& item, std::ostream& out)
{
int version = 1;
serialize(item.size(), out);
auto data = item.host();
for (size_t i = 0; i < item.size(); ++i)
serialize(data[i], out);
}
inline void deserialize(gpu_data& item, std::istream& in)
{
int version;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::gpu_data.");
size_t s;
deserialize(s, in);
item.set_size(s);
auto data = item.host();
for (size_t i = 0; i < item.size(); ++i)
deserialize(data[i], in);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
class tensor class tensor
...@@ -466,6 +491,37 @@ namespace dlib ...@@ -466,6 +491,37 @@ namespace dlib
} }
}; };
inline void serialize(const tensor& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.num_samples(), out);
serialize(item.nr(), out);
serialize(item.nc(), out);
serialize(item.k(), out);
auto data = item.host();
for (size_t i = 0; i < item.size(); ++i)
serialize(data[i], out);
}
inline void deserialize(resizable_tensor& item, std::istream& in)
{
int version;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::resizable_tensor.");
long num_samples=0, nr=0, nc=0, k=0;
deserialize(num_samples, in);
deserialize(nr, in);
deserialize(nc, in);
deserialize(k, in);
item.set_size(num_samples, nr, nc, k);
auto data = item.host();
for (size_t i = 0; i < item.size(); ++i)
deserialize(data[i], in);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
inline double dot( inline double dot(
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "../statistics.h" #include "../statistics.h"
#include "../console_progress_indicator.h" #include "../console_progress_indicator.h"
#include <chrono> #include <chrono>
#include "../serialize.h"
namespace dlib namespace dlib
{ {
...@@ -281,8 +282,34 @@ namespace dlib ...@@ -281,8 +282,34 @@ namespace dlib
return net; return net;
} }
friend void serialize(const dnn_trainer& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.num_epochs, out);
serialize(item.mini_batch_size, out);
serialize(item.verbose, out);
serialize(item.net, out);
serialize(item.solvers, out);
}
friend void deserialize(dnn_trainer& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
deserialize(item.num_epochs, in);
deserialize(item.mini_batch_size, in);
deserialize(item.verbose, in);
deserialize(item.net, in);
deserialize(item.solvers, in);
}
private: private:
const static long string_pad = 10;
void init() void init()
{ {
num_epochs = 300; num_epochs = 300;
...@@ -293,7 +320,6 @@ namespace dlib ...@@ -293,7 +320,6 @@ namespace dlib
unsigned long num_epochs; unsigned long num_epochs;
unsigned long mini_batch_size; unsigned long mini_batch_size;
bool verbose; bool verbose;
const static long string_pad = 10;
net_type net; net_type net;
sstack<solver_type,net_type::num_layers> solvers; sstack<solver_type,net_type::num_layers> solvers;
......
...@@ -222,6 +222,14 @@ namespace dlib ...@@ -222,6 +222,14 @@ namespace dlib
}; };
template <typename T, typename U>
void serialize(const dnn_trainer<T,U>& item, std::ostream& out);
template <typename T, typename U>
void deserialize(dnn_trainer<T,U>& item, std::istream& in);
/*!
provides serialization support
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment