Commit 6c36592c authored by Davis King's avatar Davis King

Added serialization support to everything.

parent e679d66a
......@@ -67,6 +67,18 @@ namespace dlib
const sstack<T,N-1>& pop() const { return data; }
sstack<T,N-1>& pop() { return data; }
friend void serialize(const sstack& item, std::ostream& out)
{
serialize(item.top(), out);
serialize(item.pop(), out);
}
friend void deserialize(sstack& item, std::istream& in)
{
deserialize(item.top(), in);
deserialize(item.pop(), in);
}
private:
T item;
sstack<T,N-1> data;
......@@ -83,6 +95,17 @@ namespace dlib
T& top() { return item; }
size_t size() const { return 1; }
friend void serialize(const sstack& item, std::ostream& out)
{
serialize(item.top(), out);
}
friend void deserialize(sstack& item, std::istream& in)
{
deserialize(item.top(), in);
}
private:
T item;
};
......@@ -294,6 +317,32 @@ namespace dlib
subnetwork.clean();
}
friend void serialize(const add_layer& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.subnetwork, out);
serialize(item.details, out);
serialize(item.this_layer_setup_called, out);
serialize(item.gradient_input_is_stale, out);
serialize(item.x_grad, out);
serialize(item.cached_output, out);
}
friend void deserialize(add_layer& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::add_layer.");
deserialize(item.subnetwork, in);
deserialize(item.details, in);
deserialize(item.this_layer_setup_called, in);
deserialize(item.gradient_input_is_stale, in);
deserialize(item.x_grad, in);
deserialize(item.cached_output, in);
}
private:
......@@ -468,6 +517,32 @@ namespace dlib
gradient_input_is_stale = true;
}
friend void serialize(const add_layer& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.input_layer, out);
serialize(item.details, out);
serialize(item.this_layer_setup_called, out);
serialize(item.gradient_input_is_stale, out);
serialize(item.x_grad, out);
serialize(item.cached_output, out);
}
friend void deserialize(add_layer& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::add_layer.");
deserialize(item.input_layer, in);
deserialize(item.details, in);
deserialize(item.this_layer_setup_called, in);
deserialize(item.gradient_input_is_stale, in);
deserialize(item.x_grad, in);
deserialize(item.cached_output, in);
}
private:
class subnet_wrapper
......@@ -601,6 +676,22 @@ namespace dlib
subnetwork.clean();
}
friend void serialize(const add_tag_layer& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.subnetwork, out);
}
friend void deserialize(add_tag_layer& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::add_tag_layer.");
deserialize(item.subnetwork, in);
}
private:
subnet_type subnetwork;
......@@ -702,6 +793,26 @@ namespace dlib
cached_output.clear();
}
friend void serialize(const add_tag_layer& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.input_layer, out);
serialize(item.cached_output, out);
serialize(item.grad_final_ignored, out);
}
friend void deserialize(add_tag_layer& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::add_tag_layer.");
deserialize(item.input_layer, in);
deserialize(item.cached_output, in);
deserialize(item.grad_final_ignored, in);
}
private:
subnet_type input_layer;
......@@ -759,7 +870,8 @@ namespace dlib
const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor;
typedef typename get_loss_layer_label_type<LOSS_DETAILS>::type label_type;
static_assert(is_nonloss_layer_type<SUBNET>::value, "SUBNET must be of type add_layer, add_skip_layer, or add_tag_layer.");
static_assert(is_nonloss_layer_type<SUBNET>::value,
"SUBNET must be of type add_layer, add_skip_layer, or add_tag_layer.");
static_assert(sample_expansion_factor == LOSS_DETAILS::sample_expansion_factor,
"The loss layer and input layer must agree on the sample_expansion_factor.");
......@@ -947,6 +1059,24 @@ namespace dlib
subnetwork.clear();
}
friend void serialize(const add_loss_layer& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.loss, out);
serialize(item.subnetwork, out);
}
friend void deserialize(add_loss_layer& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::add_loss_layer.");
deserialize(item.loss, in);
deserialize(item.subnetwork, in);
}
private:
loss_details_type loss;
......@@ -1150,6 +1280,22 @@ namespace dlib
subnetwork.clean();
}
friend void serialize(const add_skip_layer& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.subnetwork, out);
}
friend void deserialize(add_skip_layer& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::add_skip_layer.");
deserialize(item.subnetwork, in);
}
private:
subnet_type subnetwork;
......
......@@ -119,6 +119,12 @@ namespace dlib
!*/
};
void serialize(const sstack& item, std::ostream& out);
void deserialize(sstack& item, std::istream& in);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
template <
......@@ -378,6 +384,14 @@ namespace dlib
};
template <typename T, typename U>,
void serialize(const add_layer<T,U>& item, std::ostream& out);
template <typename T, typename U>,
void deserialize(add_layer<T,U>& item, std::istream& in);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......@@ -769,6 +783,14 @@ namespace dlib
!*/
};
template <typename T, typename U>,
void serialize(const add_loss_layer<T,U>& item, std::ostream& out);
template <typename T, typename U>,
void deserialize(add_loss_layer<T,U>& item, std::istream& in);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......@@ -799,6 +821,14 @@ namespace dlib
!*/
};
template <unsigned long ID, typename U>,
void serialize(const add_tag_layer<ID,U>& item, std::ostream& out);
template <unsigned long ID, typename U>,
void deserialize(add_tag_layer<ID,U>& item, std::istream& in);
/*!
provides serialization support
!*/
template <typename SUBNET> using tag1 = add_tag_layer< 1, SUBNET>;
template <typename SUBNET> using tag2 = add_tag_layer< 2, SUBNET>;
template <typename SUBNET> using tag3 = add_tag_layer< 3, SUBNET>;
......@@ -834,6 +864,14 @@ namespace dlib
!*/
};
template <template<typename> class T, typename U>
void serialize(const add_skip_layer<T,U>& item, std::ostream& out);
template <template<typename> class T, typename U>
void deserialize(add_skip_layer<T,U>& item, std::istream& in);
/*!
provides serialization support
!*/
template <typename SUBNET> using skip1 = add_skip_layer< tag1, SUBNET>;
template <typename SUBNET> using skip2 = add_skip_layer< tag2, SUBNET>;
template <typename SUBNET> using skip3 = add_skip_layer< tag3, SUBNET>;
......
......@@ -73,6 +73,20 @@ namespace dlib
}
}
friend void serialize(const input& item, std::ostream& out)
{
serialize("input<matrix>", out);
}
friend void deserialize(input& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "input<matrix>")
throw serialization_error("Unexpected version found while deserializing dlib::input.");
}
};
// ----------------------------------------------------------------------------------------
......@@ -126,6 +140,20 @@ namespace dlib
}
}
friend void serialize(const input& item, std::ostream& out)
{
serialize("input<array2d>", out);
}
friend void deserialize(input& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "input<array2d>")
throw serialization_error("Unexpected version found while deserializing dlib::input.");
}
};
// ----------------------------------------------------------------------------------------
......
......@@ -86,6 +86,12 @@ namespace dlib
!*/
};
void serialize(const EXAMPLE_INPUT_LAYER& item, std::ostream& out);
void deserialize(EXAMPLE_INPUT_LAYER& item, std::istream& in);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
template <
......@@ -132,6 +138,14 @@ namespace dlib
!*/
};
template <typename T>
void serialize(const input<T>& item, std::ostream& out);
template <typename T>
void deserialize(input<T>& item, std::istream& in);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
}
......
......@@ -59,13 +59,12 @@ namespace dlib
public:
fc_() : num_outputs(1)
{
rnd.set_seed("fc_" + cast_to_string(num_outputs));
}
explicit fc_(unsigned long num_outputs_)
explicit fc_(
unsigned long num_outputs_
) : num_outputs(num_outputs_)
{
num_outputs = num_outputs_;
rnd.set_seed("fc_" + cast_to_string(num_outputs));
}
unsigned long get_num_outputs (
......@@ -77,6 +76,7 @@ namespace dlib
num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k();
params.set_size(num_inputs, num_outputs);
dlib::rand rnd("fc_"+cast_to_string(num_outputs));
randomize_parameters(params, num_inputs+num_outputs, rnd);
}
......@@ -101,12 +101,30 @@ namespace dlib
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
friend void serialize(const fc_& item, std::ostream& out)
{
serialize("fc_", out);
serialize(item.num_outputs, out);
serialize(item.num_inputs, out);
serialize(item.params, out);
}
friend void deserialize(fc_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "fc_")
throw serialization_error("Unexpected version found while deserializing dlib::fc_.");
deserialize(item.num_outputs, in);
deserialize(item.num_inputs, in);
deserialize(item.params, in);
}
private:
unsigned long num_outputs;
unsigned long num_inputs;
resizable_tensor params;
dlib::rand rnd;
};
......@@ -151,81 +169,28 @@ namespace dlib
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
private:
resizable_tensor params;
};
template <typename SUBNET>
using relu = add_layer<relu_, SUBNET>;
// ----------------------------------------------------------------------------------------
class multiply_
{
public:
multiply_()
{
}
template <typename SUBNET>
void setup (const SUBNET& sub)
{
num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k();
params.set_size(1, num_inputs);
std::cout << "multiply_::setup() " << params.size() << std::endl;
const int num_outputs = num_inputs;
randomize_parameters(params, num_inputs+num_outputs, rnd);
}
template <typename SUBNET>
void forward(const SUBNET& sub, resizable_tensor& output)
{
DLIB_CASSERT( sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k() == params.size(), "");
DLIB_CASSERT( sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k() == num_inputs, "");
output.copy_size(sub.get_output());
auto indata = sub.get_output().host();
auto outdata = output.host();
auto paramdata = params.host();
for (int i = 0; i < sub.get_output().num_samples(); ++i)
friend void serialize(const relu_& item, std::ostream& out)
{
for (int j = 0; j < num_inputs; ++j)
{
*outdata++ = *indata++ * paramdata[j];
}
}
serialize("relu_", out);
}
template <typename SUBNET>
void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
{
params_grad += sum_rows(pointwise_multiply(mat(sub.get_output()),mat(gradient_input)));
for (long i = 0; i < gradient_input.num_samples(); ++i)
friend void deserialize(relu_& item, std::istream& in)
{
sub.get_gradient_input().add_to_sample(i,
pointwise_multiply(rowm(mat(gradient_input),i), mat(params)));
}
std::string version;
deserialize(version, in);
if (version != "relu_")
throw serialization_error("Unexpected version found while deserializing dlib::relu_.");
}
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
private:
int num_inputs;
resizable_tensor params;
dlib::rand rnd;
};
template <typename SUBNET>
using multiply = add_layer<multiply_, SUBNET>;
using relu = add_layer<relu_, SUBNET>;
// ----------------------------------------------------------------------------------------
......
......@@ -218,6 +218,12 @@ namespace dlib
};
void serialize(const EXAMPLE_LAYER_& item, std::ostream& out);
void deserialize(EXAMPLE_LAYER_& item, std::istream& in);
/*!
provides serialization support
!*/
// For each layer you define, always define an add_layer template so that layers can be
// easily composed. Moreover, the convention is that the layer class ends with an _
// while the add_layer template has the same name but without the trailing _.
......@@ -274,6 +280,11 @@ namespace dlib
!*/
};
void serialize(const fc_& item, std::ostream& out);
void deserialize(fc_& item, std::istream& in);
/*!
provides serialization support
!*/
template <typename SUBNET>
using fc = add_layer<fc_, SUBNET>;
......@@ -306,6 +317,11 @@ namespace dlib
!*/
};
void serialize(const relu_& item, std::ostream& out);
void deserialize(relu_& item, std::istream& in);
/*!
provides serialization support
!*/
template <typename SUBNET>
using relu = add_layer<relu_, SUBNET>;
......
......@@ -81,6 +81,19 @@ namespace dlib
return loss;
}
friend void serialize(const loss_binary_hinge_& item, std::ostream& out)
{
serialize("loss_binary_hinge_", out);
}
friend void deserialize(loss_binary_hinge_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "loss_binary_hinge_")
throw serialization_error("Unexpected version found while deserializing dlib::loss_binary_hinge_.");
}
};
template <typename SUBNET>
......@@ -105,6 +118,19 @@ namespace dlib
return 0;
}
friend void serialize(const loss_no_label_& item, std::ostream& out)
{
serialize("loss_no_label_", out);
}
friend void deserialize(loss_no_label_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "loss_no_label_")
throw serialization_error("Unexpected version found while deserializing dlib::loss_no_label_.");
}
};
template <typename SUBNET>
......
......@@ -118,6 +118,12 @@ namespace dlib
!*/
};
void serialize(const EXAMPLE_LOSS_LAYER_& item, std::ostream& out);
void deserialize(EXAMPLE_LOSS_LAYER_& item, std::istream& in);
/*!
provides serialization support
!*/
// For each loss layer you define, always define an add_loss_layer template so that
// layers can be easily composed. Moreover, the convention is that the layer class
// ends with an _ while the add_loss_layer template has the same name but without the
......@@ -179,6 +185,12 @@ namespace dlib
};
void serialize(const loss_binary_hinge_& item, std::ostream& out);
void deserialize(loss_binary_hinge_& item, std::istream& in);
/*!
provides serialization support
!*/
template <typename SUBNET>
using loss_binary_hinge = add_loss_layer<loss_binary_hinge_, SUBNET>;
......
......@@ -48,6 +48,27 @@ namespace dlib
l.get_layer_params() += v;
}
friend void serialize(const sgd& item, std::ostream& out)
{
serialize("sgd", out);
serialize(item.v, out);
serialize(item.weight_decay, out);
serialize(item.learning_rate, out);
serialize(item.momentum, out);
}
friend void deserialize(sgd& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "sgd")
throw serialization_error("Unexpected version found while deserializing dlib::sgd.");
deserialize(item.v, in);
deserialize(item.weight_decay, in);
deserialize(item.learning_rate, in);
deserialize(item.momentum, in);
}
private:
matrix<float> v;
float weight_decay;
......
......@@ -52,6 +52,12 @@ namespace dlib
!*/
};
void serialize(const EXAMPLE_SOLVER& item, std::ostream& out);
void deserialize(EXAMPLE_SOLVER& item, std::istream& in);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......@@ -92,6 +98,12 @@ namespace dlib
float get_momentum () const;
};
void serialize(const sgd& item, std::ostream& out);
void deserialize(sgd& item, std::istream& in);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
}
......
......@@ -112,6 +112,7 @@ namespace dlib
size_t size() const { return data_size; }
private:
void copy_to_device() const
......@@ -144,6 +145,30 @@ namespace dlib
std::unique_ptr<float[]> data_device;
};
inline void serialize(const gpu_data& item, std::ostream& out)
{
int version = 1;
serialize(item.size(), out);
auto data = item.host();
for (size_t i = 0; i < item.size(); ++i)
serialize(data[i], out);
}
inline void deserialize(gpu_data& item, std::istream& in)
{
int version;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::gpu_data.");
size_t s;
deserialize(s, in);
item.set_size(s);
auto data = item.host();
for (size_t i = 0; i < item.size(); ++i)
deserialize(data[i], in);
}
// ----------------------------------------------------------------------------------------
class tensor
......@@ -466,6 +491,37 @@ namespace dlib
}
};
inline void serialize(const tensor& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.num_samples(), out);
serialize(item.nr(), out);
serialize(item.nc(), out);
serialize(item.k(), out);
auto data = item.host();
for (size_t i = 0; i < item.size(); ++i)
serialize(data[i], out);
}
inline void deserialize(resizable_tensor& item, std::istream& in)
{
int version;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::resizable_tensor.");
long num_samples=0, nr=0, nc=0, k=0;
deserialize(num_samples, in);
deserialize(nr, in);
deserialize(nc, in);
deserialize(k, in);
item.set_size(num_samples, nr, nc, k);
auto data = item.host();
for (size_t i = 0; i < item.size(); ++i)
deserialize(data[i], in);
}
// ----------------------------------------------------------------------------------------
inline double dot(
......
......@@ -9,6 +9,7 @@
#include "../statistics.h"
#include "../console_progress_indicator.h"
#include <chrono>
#include "../serialize.h"
namespace dlib
{
......@@ -281,8 +282,34 @@ namespace dlib
return net;
}
friend void serialize(const dnn_trainer& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.num_epochs, out);
serialize(item.mini_batch_size, out);
serialize(item.verbose, out);
serialize(item.net, out);
serialize(item.solvers, out);
}
friend void deserialize(dnn_trainer& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
deserialize(item.num_epochs, in);
deserialize(item.mini_batch_size, in);
deserialize(item.verbose, in);
deserialize(item.net, in);
deserialize(item.solvers, in);
}
private:
const static long string_pad = 10;
void init()
{
num_epochs = 300;
......@@ -293,7 +320,6 @@ namespace dlib
unsigned long num_epochs;
unsigned long mini_batch_size;
bool verbose;
const static long string_pad = 10;
net_type net;
sstack<solver_type,net_type::num_layers> solvers;
......
......@@ -222,6 +222,14 @@ namespace dlib
};
template <typename T, typename U>
void serialize(const dnn_trainer<T,U>& item, std::ostream& out);
template <typename T, typename U>
void deserialize(dnn_trainer<T,U>& item, std::istream& in);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment