Commit 99ce564b authored by Davis King's avatar Davis King

Changed the tensor serialization format to use a 4 byte little endian IEEE

representation rather than dlib's default variable length encoding.  This
change makes the resulting serialized networks about 33% smaller.
parent 9d5c2b74
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "../matrix.h" #include "../matrix.h"
#include "cudnn_dlibapi.h" #include "cudnn_dlibapi.h"
#include "gpu_data.h" #include "gpu_data.h"
#include "../byte_orderer.h"
#include <memory> #include <memory>
namespace dlib namespace dlib
...@@ -378,21 +379,33 @@ namespace dlib ...@@ -378,21 +379,33 @@ namespace dlib
inline void serialize(const tensor& item, std::ostream& out) inline void serialize(const tensor& item, std::ostream& out)
{ {
int version = 1; int version = 2;
serialize(version, out); serialize(version, out);
serialize(item.num_samples(), out); serialize(item.num_samples(), out);
serialize(item.k(), out); serialize(item.k(), out);
serialize(item.nr(), out); serialize(item.nr(), out);
serialize(item.nc(), out); serialize(item.nc(), out);
for (auto& d : item) byte_orderer bo;
serialize(d, out); auto sbuf = out.rdbuf();
for (auto d : item)
{
// Write out our data as 4byte little endian IEEE floats rather than using
// dlib's default float serialization. We do this because it will result in
// more compact outputs. It's slightly less portable but it seems doubtful
// that any CUDA enabled platform isn't going to use IEEE floats. But if one
// does we can just update the serialization code here to handle it if such a
// platform is encountered.
bo.host_to_little(d);
static_assert(sizeof(d)==4, "This serialization code assumes we are writing 4 byte floats");
sbuf->sputn((char*)&d, sizeof(d));
}
} }
inline void deserialize(resizable_tensor& item, std::istream& in) inline void deserialize(resizable_tensor& item, std::istream& in)
{ {
int version; int version;
deserialize(version, in); deserialize(version, in);
if (version != 1) if (version != 2)
throw serialization_error("Unexpected version found while deserializing dlib::resizable_tensor."); throw serialization_error("Unexpected version found while deserializing dlib::resizable_tensor.");
long num_samples=0, k=0, nr=0, nc=0; long num_samples=0, k=0, nr=0, nc=0;
...@@ -401,8 +414,18 @@ namespace dlib ...@@ -401,8 +414,18 @@ namespace dlib
deserialize(nr, in); deserialize(nr, in);
deserialize(nc, in); deserialize(nc, in);
item.set_size(num_samples, k, nr, nc); item.set_size(num_samples, k, nr, nc);
byte_orderer bo;
auto sbuf = in.rdbuf();
for (auto& d : item) for (auto& d : item)
deserialize(d, in); {
static_assert(sizeof(d)==4, "This serialization code assumes we are writing 4 byte floats");
if (sbuf->sgetn((char*)&d,sizeof(d)) != sizeof(d))
{
in.setstate(std::ios::badbit);
throw serialization_error("Error reading data while deserializing dlib::resizable_tensor.");
}
bo.little_to_host(d);
}
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment