Commit 247beb3d authored by Davis King's avatar Davis King

Removed support for old serialization formats in many of the DNN objects. This

is to clean up the code since it was getting somewhat complex and this is the
last opportunity to do this kind of cleanup prior to the release of dlib v19.0.

If you have saved network objects and want to convert them to the current
format, then make sure you checkout the previous commit (labeled with tag
before_dnn_serialization_cleanup) and then deserialize and serialize your
network back to disk.
parent fc7f9b6c
......@@ -209,80 +209,11 @@ namespace dlib
if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
return;
}
if (version == "con_")
{
deserialize(item.params, in);
deserialize(num_filters, in);
deserialize(nr, in);
deserialize(nc, in);
deserialize(stride_y, in);
deserialize(stride_x, in);
deserialize(item.filters, in);
deserialize(item.biases, in);
item.padding_y_ = nr/2;
item.padding_x_ = nc/2;
}
else if (version == "con_2" || version == "con_3")
{
deserialize(item.params, in);
deserialize(num_filters, in);
deserialize(nr, in);
deserialize(nc, in);
deserialize(stride_y, in);
deserialize(stride_x, in);
deserialize(item.padding_y_, in);
deserialize(item.padding_x_, in);
deserialize(item.filters, in);
deserialize(item.biases, in);
if (version == "con_3")
{
deserialize(item.learning_rate_multiplier, in);
deserialize(item.weight_decay_multiplier, in);
deserialize(item.bias_learning_rate_multiplier, in);
deserialize(item.bias_weight_decay_multiplier, in);
}
else
{
// Previous versions didn't have these parameters, so they were
// implicitly 1.
item.learning_rate_multiplier = 1;
item.weight_decay_multiplier = 1;
item.bias_learning_rate_multiplier = 1;
item.bias_weight_decay_multiplier = 1;
}
if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::con_");
if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::con_");
}
else
{
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_.");
}
// now flip all the filters
alias_tensor at(_nr, _nc);
size_t off = 0;
for (int i = 0; i < item.filters.num_samples(); ++i)
{
for (int j = 0; j < item.filters.k(); ++j)
{
auto temp = at(item.params,off);
off += _nr*_nc;
temp = flipud(fliplr(mat(temp)));
}
}
if (num_filters != _num_filters) throw serialization_error("Wrong num_filters found while deserializing dlib::con_");
if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_");
if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
}
......@@ -458,16 +389,7 @@ namespace dlib
long nc;
int stride_y;
int stride_x;
if (version == "max_pool_")
{
deserialize(nr, in);
deserialize(nc, in);
deserialize(stride_y, in);
deserialize(stride_x, in);
item.padding_y_ = nr/2;
item.padding_x_ = nc/2;
}
else if (version == "max_pool_2")
if (version == "max_pool_2")
{
deserialize(nr, in);
deserialize(nc, in);
......@@ -475,14 +397,14 @@ namespace dlib
deserialize(stride_x, in);
deserialize(item.padding_y_, in);
deserialize(item.padding_x_, in);
if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::max_pool_");
if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::max_pool_");
}
else
{
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::max_pool_.");
}
if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::max_pool_");
if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::max_pool_");
if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::max_pool_");
if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::max_pool_");
if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::max_pool_");
......@@ -647,16 +569,7 @@ namespace dlib
long nc;
int stride_y;
int stride_x;
if (version == "avg_pool_")
{
deserialize(nr, in);
deserialize(nc, in);
deserialize(stride_y, in);
deserialize(stride_x, in);
item.padding_y_ = nr/2;
item.padding_x_ = nc/2;
}
else if (version == "avg_pool_2")
if (version == "avg_pool_2")
{
deserialize(nr, in);
deserialize(nc, in);
......@@ -664,14 +577,14 @@ namespace dlib
deserialize(stride_x, in);
deserialize(item.padding_y_, in);
deserialize(item.padding_x_, in);
if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::avg_pool_");
if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::avg_pool_");
}
else
{
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::avg_pool_.");
}
if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::avg_pool_");
if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::avg_pool_");
if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::avg_pool_");
if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::avg_pool_");
if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::avg_pool_");
......@@ -864,19 +777,16 @@ namespace dlib
{
std::string version;
deserialize(version, in);
if (version != "bn_")
{
if (mode == CONV_MODE)
{
if (version != "bn_con" && version != "bn_con2")
if (version != "bn_con2")
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
}
else // must be in FC_MODE
{
if (version != "bn_fc" && version != "bn_fc2")
if (version != "bn_fc2")
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
}
}
deserialize(item.params, in);
deserialize(item.gamma, in);
......@@ -887,37 +797,12 @@ namespace dlib
deserialize(item.running_variances, in);
deserialize(item.num_updates, in);
deserialize(item.running_stats_window_size, in);
// if this is the older "bn_" version then check its saved mode value and make
// sure it is the one we are really using.
if (version == "bn_")
{
int _mode;
deserialize(_mode, in);
if (mode != (layer_mode)_mode) throw serialization_error("Wrong mode found while deserializing dlib::bn_");
// We also need to flip the running_variances around since the previous
// format saved the inverse standard deviations instead of variances.
item.running_variances = 1.0f/squared(mat(item.running_variances)) - DEFAULT_BATCH_NORM_EPS;
}
else if (version == "bn_con2" || version == "bn_fc2")
{
deserialize(item.learning_rate_multiplier, in);
deserialize(item.weight_decay_multiplier, in);
deserialize(item.bias_learning_rate_multiplier, in);
deserialize(item.bias_weight_decay_multiplier, in);
deserialize(item.eps, in);
}
else
{
// Previous versions didn't have these parameters, so they were
// implicitly 1.
item.learning_rate_multiplier = 1;
item.weight_decay_multiplier = 1;
item.eps = DEFAULT_BATCH_NORM_EPS;
}
}
friend std::ostream& operator<<(std::ostream& out, const bn_& item)
{
......@@ -1106,7 +991,7 @@ namespace dlib
{
std::string version;
deserialize(version, in);
if (version != "fc_" && version != "fc_2")
if (version != "fc_2")
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::fc_.");
deserialize(item.num_outputs, in);
......@@ -1117,23 +1002,11 @@ namespace dlib
int bmode = 0;
deserialize(bmode, in);
if (bias_mode != (fc_bias_mode)bmode) throw serialization_error("Wrong fc_bias_mode found while deserializing dlib::fc_");
if (version == "fc_2")
{
deserialize(item.learning_rate_multiplier, in);
deserialize(item.weight_decay_multiplier, in);
deserialize(item.bias_learning_rate_multiplier, in);
deserialize(item.bias_weight_decay_multiplier, in);
}
else
{
// Previous versions didn't have these parameters, so they were
// implicitly 1.
item.learning_rate_multiplier = 1;
item.weight_decay_multiplier = 1;
item.bias_learning_rate_multiplier = 1;
item.bias_weight_decay_multiplier = 1;
}
}
friend std::ostream& operator<<(std::ostream& out, const fc_& item)
{
......@@ -1525,7 +1398,7 @@ namespace dlib
{
std::string version;
deserialize(version, in);
if (version == "bn_con" || version == "bn_con2")
if (version == "bn_con2")
{
// Since we can build an affine_ from a bn_ we check if that's what is in
// the stream and if so then just convert it right here.
......@@ -1535,7 +1408,7 @@ namespace dlib
item = temp;
return;
}
else if (version == "bn_fc" || version == "bn_fc2")
else if (version == "bn_fc2")
{
// Since we can build an affine_ from a bn_ we check if that's what is in
// the stream and if so then just convert it right here.
......
......@@ -731,7 +731,7 @@ namespace dlib
item.wait_for_thread_to_pause();
int version = 0;
deserialize(version, in);
if (version != 6 && version != 7)
if (version != 7)
throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
size_t num_layers = 0;
......@@ -761,16 +761,8 @@ namespace dlib
deserialize(item.epoch_iteration, in);
deserialize(item.epoch_pos, in);
deserialize(item.train_one_step_calls, in);
if (version == 7)
{
deserialize(item.lr_schedule, in);
deserialize(item.lr_schedule_pos, in);
}
else
{
item.lr_schedule.set_size(0);
item.lr_schedule_pos = 0;
}
if (item.devices.size() > 1)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment