Commit fe168596 authored by Davis King

Moved most of the layer parameters from runtime variables set in constructors
to template arguments.  This way, the type of a network specifies the entire
network architecture and most of the time the user doesn't even need to do
anything with layer constructors.
parent 001bca78
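With this change a layer that used to be configured at runtime, e.g. con_ l(16,5,5,1,1), is written con_<16,5,5,1,1>, so a default-constructed network object is already fully specified by its type. A minimal sketch of the new style, using only names that appear in this diff:

    // The type alone pins down the whole architecture; no constructor
    // arguments are needed anymore:
    using net_type = loss_multiclass_log<
                         fc<10,FC_HAS_BIAS,
                         relu<con<6,5,5,1,1,
                         input<matrix<unsigned char>>>>>>;
    net_type net;  // ready to train or deserialize into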
@@ -19,31 +19,25 @@ namespace dlib
     // ----------------------------------------------------------------------------------------

+    template <
+        long _num_filters,
+        long _nr,
+        long _nc,
+        int _stride_y,
+        int _stride_x
+        >
     class con_
     {
     public:
-        con_ (
-        ) :
-            _num_filters(1),
-            _nr(3),
-            _nc(3),
-            _stride_y(1),
-            _stride_x(1)
-        {}
+        static_assert(_num_filters > 0, "The number of filters must be > 0");
+        static_assert(_nr > 0, "The number of rows in a filter must be > 0");
+        static_assert(_nc > 0, "The number of columns in a filter must be > 0");
+        static_assert(_stride_y > 0, "The filter stride must be > 0");
+        static_assert(_stride_x > 0, "The filter stride must be > 0");

         con_(
-            long num_filters_,
-            long nr_,
-            long nc_,
-            int stride_y_ = 1,
-            int stride_x_ = 1
-        ) :
-            _num_filters(num_filters_),
-            _nr(nr_),
-            _nc(nc_),
-            _stride_y(stride_y_),
-            _stride_x(stride_x_)
+        )
         {}

         long num_filters() const { return _num_filters; }
@@ -56,11 +50,6 @@ namespace dlib
             const con_& item
         ) :
             params(item.params),
-            _num_filters(item._num_filters),
-            _nr(item._nr),
-            _nc(item._nc),
-            _stride_y(item._stride_y),
-            _stride_x(item._stride_x),
             filters(item.filters),
             biases(item.biases)
         {
@@ -78,11 +67,6 @@ namespace dlib
             // this->conv is non-copyable and basically stateless, so we have to write our
             // own copy to avoid trying to copy it and getting an error.
             params = item.params;
-            _num_filters = item._num_filters;
-            _nr = item._nr;
-            _nc = item._nc;
-            _stride_y = item._stride_y;
-            _stride_x = item._stride_x;
             filters = item.filters;
             biases = item.biases;
             return *this;
@@ -135,11 +119,11 @@ namespace dlib
         {
             serialize("con_", out);
             serialize(item.params, out);
-            serialize(item._num_filters, out);
-            serialize(item._nr, out);
-            serialize(item._nc, out);
-            serialize(item._stride_y, out);
-            serialize(item._stride_x, out);
+            serialize(_num_filters, out);
+            serialize(_nr, out);
+            serialize(_nc, out);
+            serialize(_stride_y, out);
+            serialize(_stride_x, out);
             serialize(item.filters, out);
             serialize(item.biases, out);
         }
@@ -151,57 +135,66 @@ namespace dlib
             if (version != "con_")
                 throw serialization_error("Unexpected version found while deserializing dlib::con_.");
             deserialize(item.params, in);
-            deserialize(item._num_filters, in);
-            deserialize(item._nr, in);
-            deserialize(item._nc, in);
-            deserialize(item._stride_y, in);
-            deserialize(item._stride_x, in);
+            long num_filters;
+            long nr;
+            long nc;
+            int stride_y;
+            int stride_x;
+            deserialize(num_filters, in);
+            deserialize(nr, in);
+            deserialize(nc, in);
+            deserialize(stride_y, in);
+            deserialize(stride_x, in);
             deserialize(item.filters, in);
             deserialize(item.biases, in);
+            if (num_filters != _num_filters) throw serialization_error("Wrong num_filters found while deserializing dlib::con_");
+            if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_");
+            if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
+            if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
+            if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
         }

     private:

         resizable_tensor params;
-        long _num_filters;
-        long _nr;
-        long _nc;
-        int _stride_y;
-        int _stride_x;
         alias_tensor filters, biases;
         tt::tensor_conv conv;
     };

-    template <typename SUBNET>
-    using con = add_layer<con_, SUBNET>;
+    template <
+        long num_filters,
+        long nr,
+        long nc,
+        int stride_y,
+        int stride_x,
+        typename SUBNET
+        >
+    using con = add_layer<con_<num_filters,nr,nc,stride_y,stride_x>, SUBNET>;
     // ----------------------------------------------------------------------------------------

+    template <
+        long _nr,
+        long _nc,
+        int _stride_y,
+        int _stride_x
+        >
     class max_pool_
     {
+        static_assert(_nr > 0, "The number of rows in a filter must be > 0");
+        static_assert(_nc > 0, "The number of columns in a filter must be > 0");
+        static_assert(_stride_y > 0, "The filter stride must be > 0");
+        static_assert(_stride_x > 0, "The filter stride must be > 0");
     public:
-        max_pool_ (
-        ) :
-            _nr(3),
-            _nc(3),
-            _stride_y(1),
-            _stride_x(1)
-        {}

         max_pool_(
-            long nr_,
-            long nc_,
-            int stride_y_ = 1,
-            int stride_x_ = 1
-        ) :
-            _nr(nr_),
-            _nc(nc_),
-            _stride_y(stride_y_),
-            _stride_x(stride_x_)
-        {}
+        ) {}

         long nr() const { return _nr; }
         long nc() const { return _nc; }
@@ -209,12 +202,8 @@ namespace dlib
         long stride_x() const { return _stride_x; }

         max_pool_ (
-            const max_pool_& item
-        ) :
-            _nr(item._nr),
-            _nc(item._nc),
-            _stride_y(item._stride_y),
-            _stride_x(item._stride_x)
+            const max_pool_&
+        )
         {
             // this->mp is non-copyable so we have to write our own copy to avoid trying to
             // copy it and getting an error.
@@ -230,11 +219,6 @@ namespace dlib
             // this->mp is non-copyable so we have to write our own copy to avoid trying to
             // copy it and getting an error.
-            _nr = item._nr;
-            _nc = item._nc;
-            _stride_y = item._stride_y;
-            _stride_x = item._stride_x;
             mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x);
             return *this;
         }
@@ -263,10 +247,10 @@ namespace dlib
         friend void serialize(const max_pool_& item, std::ostream& out)
         {
             serialize("max_pool_", out);
-            serialize(item._nr, out);
-            serialize(item._nc, out);
-            serialize(item._stride_y, out);
-            serialize(item._stride_x, out);
+            serialize(_nr, out);
+            serialize(_nc, out);
+            serialize(_stride_y, out);
+            serialize(_stride_x, out);
         }

         friend void deserialize(max_pool_& item, std::istream& in)
@@ -275,53 +259,58 @@ namespace dlib
             deserialize(version, in);
             if (version != "max_pool_")
                 throw serialization_error("Unexpected version found while deserializing dlib::max_pool_.");
-            deserialize(item._nr, in);
-            deserialize(item._nc, in);
-            deserialize(item._stride_y, in);
-            deserialize(item._stride_x, in);
-            item.mp.setup_max_pooling(item._nr, item._nc, item._stride_y, item._stride_x);
+            item.mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x);
+
+            long nr;
+            long nc;
+            int stride_y;
+            int stride_x;
+            deserialize(nr, in);
+            deserialize(nc, in);
+            deserialize(stride_y, in);
+            deserialize(stride_x, in);
+            if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::max_pool_");
+            if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::max_pool_");
+            if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::max_pool_");
+            if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::max_pool_");
         }

     private:

-        long _nr;
-        long _nc;
-        int _stride_y;
-        int _stride_x;
         tt::pooling mp;
         resizable_tensor params;
     };

-    template <typename SUBNET>
-    using max_pool = add_layer<max_pool_, SUBNET>;
+    template <
+        long nr,
+        long nc,
+        int stride_y,
+        int stride_x,
+        typename SUBNET
+        >
+    using max_pool = add_layer<max_pool_<nr,nc,stride_y,stride_x>, SUBNET>;
     // ----------------------------------------------------------------------------------------

+    template <
+        long _nr,
+        long _nc,
+        int _stride_y,
+        int _stride_x
+        >
     class avg_pool_
     {
     public:
+        static_assert(_nr > 0, "The number of rows in a filter must be > 0");
+        static_assert(_nc > 0, "The number of columns in a filter must be > 0");
+        static_assert(_stride_y > 0, "The filter stride must be > 0");
+        static_assert(_stride_x > 0, "The filter stride must be > 0");

-        avg_pool_ (
-        ) :
-            _nr(3),
-            _nc(3),
-            _stride_y(1),
-            _stride_x(1)
-        {}

         avg_pool_(
-            long nr_,
-            long nc_,
-            int stride_y_ = 1,
-            int stride_x_ = 1
-        ) :
-            _nr(nr_),
-            _nc(nc_),
-            _stride_y(stride_y_),
-            _stride_x(stride_x_)
-        {}
+        ) {}

         long nr() const { return _nr; }
         long nc() const { return _nc; }
@@ -329,12 +318,8 @@ namespace dlib
         long stride_x() const { return _stride_x; }

         avg_pool_ (
-            const avg_pool_& item
-        ) :
-            _nr(item._nr),
-            _nc(item._nc),
-            _stride_y(item._stride_y),
-            _stride_x(item._stride_x)
+            const avg_pool_&
+        )
         {
             // this->ap is non-copyable so we have to write our own copy to avoid trying to
             // copy it and getting an error.
@@ -350,11 +335,6 @@ namespace dlib
             // this->ap is non-copyable so we have to write our own copy to avoid trying to
             // copy it and getting an error.
-            _nr = item._nr;
-            _nc = item._nc;
-            _stride_y = item._stride_y;
-            _stride_x = item._stride_x;
             ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x);
             return *this;
         }
@@ -383,10 +363,10 @@ namespace dlib
         friend void serialize(const avg_pool_& item, std::ostream& out)
         {
             serialize("avg_pool_", out);
-            serialize(item._nr, out);
-            serialize(item._nc, out);
-            serialize(item._stride_y, out);
-            serialize(item._stride_x, out);
+            serialize(_nr, out);
+            serialize(_nc, out);
+            serialize(_stride_y, out);
+            serialize(_stride_x, out);
         }

         friend void deserialize(avg_pool_& item, std::istream& in)
@@ -395,27 +375,38 @@ namespace dlib
             deserialize(version, in);
             if (version != "avg_pool_")
                 throw serialization_error("Unexpected version found while deserializing dlib::avg_pool_.");
-            deserialize(item._nr, in);
-            deserialize(item._nc, in);
-            deserialize(item._stride_y, in);
-            deserialize(item._stride_x, in);
-            item.ap.setup_avg_pooling(item._nr, item._nc, item._stride_y, item._stride_x);
+            item.ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x);
+
+            long nr;
+            long nc;
+            int stride_y;
+            int stride_x;
+            deserialize(nr, in);
+            deserialize(nc, in);
+            deserialize(stride_y, in);
+            deserialize(stride_x, in);
+            if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::avg_pool_");
+            if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::avg_pool_");
+            if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::avg_pool_");
+            if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::avg_pool_");
         }

     private:

-        long _nr;
-        long _nc;
-        int _stride_y;
-        int _stride_x;
         tt::pooling ap;
         resizable_tensor params;
     };

-    template <typename SUBNET>
-    using avg_pool = add_layer<avg_pool_, SUBNET>;
+    template <
+        long nr,
+        long nc,
+        int stride_y,
+        int stride_x,
+        typename SUBNET
+        >
+    using avg_pool = add_layer<avg_pool_<nr,nc,stride_y,stride_x>, SUBNET>;
     // ----------------------------------------------------------------------------------------

@@ -425,16 +416,16 @@ namespace dlib
         FC_MODE = 1
     };

+    template <
+        layer_mode mode
+        >
     class bn_
     {
     public:
-        bn_() : num_updates(0), running_stats_window_size(1000), mode(FC_MODE)
+        bn_() : num_updates(0), running_stats_window_size(1000)
         {}

-        explicit bn_(layer_mode mode_) : num_updates(0), running_stats_window_size(1000), mode(mode_)
-        {}
-
-        bn_(layer_mode mode_, unsigned long window_size) : num_updates(0), running_stats_window_size(window_size), mode(mode_)
+        explicit bn_(unsigned long window_size) : num_updates(0), running_stats_window_size(window_size)
         {}

         layer_mode get_mode() const { return mode; }
@@ -519,7 +510,7 @@ namespace dlib
             serialize(item.running_invstds, out);
             serialize(item.num_updates, out);
             serialize(item.running_stats_window_size, out);
-            serialize((int)item.mode, out);
+            serialize((int)mode, out);
         }

         friend void deserialize(bn_& item, std::istream& in)
@@ -537,13 +528,14 @@ namespace dlib
             deserialize(item.running_invstds, in);
             deserialize(item.num_updates, in);
             deserialize(item.running_stats_window_size, in);
-            int mode;
-            deserialize(mode, in);
-            item.mode = (layer_mode)mode;
+            int _mode;
+            deserialize(_mode, in);
+            if (mode != (layer_mode)_mode) throw serialization_error("Wrong mode found while deserializing dlib::bn_");
         }

     private:

+        template < layer_mode Mode >
         friend class affine_;

         resizable_tensor params;
@@ -552,32 +544,41 @@ namespace dlib
         resizable_tensor invstds, running_invstds;
         unsigned long num_updates;
         unsigned long running_stats_window_size;
-        layer_mode mode;
     };

     template <typename SUBNET>
-    using bn = add_layer<bn_, SUBNET>;
+    using bn_con = add_layer<bn_<CONV_MODE>, SUBNET>;
+
+    template <typename SUBNET>
+    using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;
     // ----------------------------------------------------------------------------------------

-    enum fc_bias_mode{
+    enum fc_bias_mode
+    {
         FC_HAS_BIAS = 0,
         FC_NO_BIAS = 1
     };

+    struct num_fc_outputs
+    {
+        num_fc_outputs(unsigned long n) : num_outputs(n) {}
+        unsigned long num_outputs;
+    };
+
+    template <
+        unsigned long num_outputs_,
+        fc_bias_mode bias_mode
+        >
     class fc_
     {
+        static_assert(num_outputs_ > 0, "The number of outputs from a fc_ layer must be > 0");
     public:
-        fc_() : num_outputs(1), num_inputs(0), bias_mode(FC_HAS_BIAS)
+        fc_() : num_outputs(num_outputs_), num_inputs(0)
         {
         }

-        explicit fc_(
-            unsigned long num_outputs_,
-            fc_bias_mode mode = FC_HAS_BIAS
-        ) : num_outputs(num_outputs_), num_inputs(0), bias_mode(mode)
-        {
-        }
+        fc_(num_fc_outputs o) : num_outputs(o.num_outputs), num_inputs(0) {}

         unsigned long get_num_outputs (
         ) const { return num_outputs; }
@@ -651,7 +652,7 @@ namespace dlib
             serialize(item.params, out);
             serialize(item.weights, out);
             serialize(item.biases, out);
-            serialize((int)item.bias_mode, out);
+            serialize((int)bias_mode, out);
         }

         friend void deserialize(fc_& item, std::istream& in)
@@ -660,6 +661,7 @@ namespace dlib
             deserialize(version, in);
             if (version != "fc_")
                 throw serialization_error("Unexpected version found while deserializing dlib::fc_.");
             deserialize(item.num_outputs, in);
             deserialize(item.num_inputs, in);
             deserialize(item.params, in);
@@ -667,7 +669,7 @@ namespace dlib
             deserialize(item.biases, in);
             int bmode = 0;
             deserialize(bmode, in);
-            item.bias_mode = (fc_bias_mode)bmode;
+            if (bias_mode != (fc_bias_mode)bmode) throw serialization_error("Wrong fc_bias_mode found while deserializing dlib::fc_");
         }

     private:

@@ -676,11 +678,14 @@ namespace dlib
         unsigned long num_inputs;
         resizable_tensor params;
         alias_tensor weights, biases;
-        fc_bias_mode bias_mode;
     };

-    template <typename SUBNET>
-    using fc = add_layer<fc_, SUBNET>;
+    template <
+        unsigned long num_outputs,
+        fc_bias_mode bias_mode,
+        typename SUBNET
+        >
+    using fc = add_layer<fc_<num_outputs,bias_mode>, SUBNET>;
     // ----------------------------------------------------------------------------------------

@@ -849,27 +854,22 @@ namespace dlib

     // ----------------------------------------------------------------------------------------

+    template <
+        layer_mode mode
+        >
     class affine_
     {
     public:
         affine_(
-        ) : mode(FC_MODE)
-        {
-        }
-
-        explicit affine_(
-            layer_mode mode_
-        ) : mode(mode_)
-        {
-        }
+        )
+        {}

         affine_(
-            const bn_& item
+            const bn_<mode>& item
         )
         {
             gamma = item.gamma;
             beta = item.beta;
-            mode = item.mode;

             params.copy_size(item.params);
@@ -959,7 +959,7 @@ namespace dlib
                 // Since we can build an affine_ from a bn_ we check if that's what is in
                 // the stream and if so then just convert it right here.
                 unserialize sin(version, in);
-                bn_ temp;
+                bn_<mode> temp;
                 deserialize(temp, sin);
                 item = temp;
                 return;
@@ -970,19 +970,20 @@ namespace dlib
             deserialize(item.params, in);
             deserialize(item.gamma, in);
             deserialize(item.beta, in);
-            int mode;
-            deserialize(mode, in);
-            item.mode = (layer_mode)mode;
+            int _mode;
+            deserialize(_mode, in);
+            if (mode != (layer_mode)_mode) throw serialization_error("Wrong mode found while deserializing dlib::affine_");
         }

     private:

         resizable_tensor params, empty_params;
         alias_tensor gamma, beta;
-        layer_mode mode;
     };

     template <typename SUBNET>
-    using affine = add_layer<affine_, SUBNET>;
+    using affine_con = add_layer<affine_<CONV_MODE>, SUBNET>;
+
+    template <typename SUBNET>
+    using affine_fc = add_layer<affine_<FC_MODE>, SUBNET>;
     // ----------------------------------------------------------------------------------------

@@ -1129,6 +1130,9 @@ namespace dlib
         {
         }

+        float get_initial_param_value (
+        ) const { return initial_param_value; }
+
         template <typename SUBNET>
         void setup (const SUBNET& /*sub*/)
         {
...
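A consequence of the new deserialize() overloads above: the layer parameters stored in a stream are no longer assigned into the object, they are checked against the template arguments, and a mismatch throws. A small sketch of what this means for user code (hypothetical file name, using dlib's stream-proxy serialization the same way the examples below do):

    con_<16,5,5,1,1> saved;
    serialize("layer.dat") << saved;          // writes 16,5,5,1,1 into the stream

    con_<16,5,5,1,1> same;
    deserialize("layer.dat") >> same;         // fine, parameters match

    con_<8,3,3,1,1> different;
    deserialize("layer.dat") >> different;    // throws dlib::serialization_error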
@@ -322,14 +322,28 @@ namespace dlib
    // ----------------------------------------------------------------------------------------
    // ----------------------------------------------------------------------------------------
-    enum fc_bias_mode{
+    enum fc_bias_mode
+    {
         FC_HAS_BIAS = 0,
         FC_NO_BIAS = 1
     };

+    struct num_fc_outputs
+    {
+        num_fc_outputs(unsigned long n) : num_outputs(n) {}
+        unsigned long num_outputs;
+    };
+
+    template <
+        unsigned long num_outputs,
+        fc_bias_mode bias_mode
+        >
     class fc_
     {
         /*!
+            REQUIREMENTS ON num_outputs
+                num_outputs > 0
+
             WHAT THIS OBJECT REPRESENTS
                 This is an implementation of the EXAMPLE_LAYER_ interface defined above.
                 In particular, it defines a fully connected layer that takes an input
@@ -337,24 +351,13 @@ namespace dlib
         !*/

     public:
-        fc_(
-        );
-        /*!
-            ensures
-                - #get_num_outputs() == 1
-                - #get_bias_mode() == FC_HAS_BIAS
-        !*/
-
-        explicit fc_(
-            unsigned long num_outputs,
-            fc_bias_mode mode = FC_HAS_BIAS
+        fc_(
         );
         /*!
-            requires
-                - num_outputs > 0
             ensures
                 - #get_num_outputs() == num_outputs
-                - #get_bias_mode() == mode
+                - #get_bias_mode() == bias_mode
         !*/
         unsigned long get_num_outputs (
@@ -385,22 +388,37 @@ namespace dlib
         /*!
             These functions are implemented as described in the EXAMPLE_LAYER_ interface.
         !*/
-    };

-    void serialize(const fc_& item, std::ostream& out);
-    void deserialize(fc_& item, std::istream& in);
+        friend void serialize(const fc_& item, std::ostream& out);
+        friend void deserialize(fc_& item, std::istream& in);
         /*!
             provides serialization support
         !*/
+    };

-    template <typename SUBNET>
-    using fc = add_layer<fc_, SUBNET>;
+    template <
+        unsigned long num_outputs,
+        fc_bias_mode bias_mode,
+        typename SUBNET
+        >
+    using fc = add_layer<fc_<num_outputs,bias_mode>, SUBNET>;
     // ----------------------------------------------------------------------------------------

+    template <
+        long _num_filters,
+        long _nr,
+        long _nc,
+        int _stride_y,
+        int _stride_x
+        >
     class con_
     {
         /*!
+            REQUIREMENTS ON TEMPLATE ARGUMENTS
+                All of them must be > 0.
+
             WHAT THIS OBJECT REPRESENTS
                 This is an implementation of the EXAMPLE_LAYER_ interface defined above.
                 In particular, it defines a convolution layer that takes an input tensor
@@ -420,33 +438,11 @@ namespace dlib
         );
         /*!
             ensures
-                - #num_filters() == 1
-                - #nr() == 3
-                - #nc() == 3
-                - #stride_y() == 1
-                - #stride_x() == 1
-        !*/
-
-        con_(
-            long num_filters_,
-            long nr_,
-            long nc_,
-            int stride_y_ = 1,
-            int stride_x_ = 1
-        );
-        /*!
-            requires
-                - num_filters_ > 0
-                - nr_ > 0
-                - nc_ > 0
-                - stride_y_ > 0
-                - stride_x_ > 0
-            ensures
-                - #num_filters() == num_filters_
-                - #nr() == nr_
-                - #nc() == nc_
-                - #stride_y() == stride_y_
-                - #stride_x() == stride_x_
+                - #num_filters() == _num_filters
+                - #nr() == _nr
+                - #nc() == _nc
+                - #stride_y() == _stride_y
+                - #stride_x() == _stride_x
         !*/

         long num_filters(
@@ -498,16 +494,24 @@ namespace dlib
         /*!
             These functions are implemented as described in the EXAMPLE_LAYER_ interface.
         !*/
-    };

-    void serialize(const con_& item, std::ostream& out);
-    void deserialize(con_& item, std::istream& in);
+        friend void serialize(const con_& item, std::ostream& out);
+        friend void deserialize(con_& item, std::istream& in);
         /*!
             provides serialization support
         !*/
+    };

-    template <typename SUBNET>
-    using con = add_layer<con_, SUBNET>;
+    template <
+        long num_filters,
+        long nr,
+        long nc,
+        int stride_y,
+        int stride_x,
+        typename SUBNET
+        >
+    using con = add_layer<con_<num_filters,nr,nc,stride_y,stride_x>, SUBNET>;
     // ----------------------------------------------------------------------------------------

@@ -631,6 +635,9 @@ namespace dlib
         FC_MODE = 1 // fully connected mode
     };

+    template <
+        layer_mode mode
+        >
     class bn_
     {
         /*!
@@ -663,17 +670,17 @@ namespace dlib
         );
         /*!
             ensures
-                - #get_mode() == FC_MODE
+                - #get_mode() == mode
                 - get_running_stats_window_size() == 1000
         !*/

         explicit bn_(
-            layer_mode mode
+            unsigned long window_size
         );
         /*!
             ensures
                 - #get_mode() == mode
-                - get_running_stats_window_size() == 1000
+                - get_running_stats_window_size() == window_size
         !*/

         layer_mode get_mode(
@@ -713,19 +720,25 @@ namespace dlib
         /*!
             These functions are implemented as described in the EXAMPLE_LAYER_ interface.
         !*/
-    };

-    void serialize(const bn_& item, std::ostream& out);
-    void deserialize(bn_& item, std::istream& in);
+        friend void serialize(const bn_& item, std::ostream& out);
+        friend void deserialize(bn_& item, std::istream& in);
         /*!
             provides serialization support
         !*/
+    };

+    template <typename SUBNET>
+    using bn_con = add_layer<bn_<CONV_MODE>, SUBNET>;
+
     template <typename SUBNET>
-    using bn = add_layer<bn_, SUBNET>;
+    using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;
     // ----------------------------------------------------------------------------------------

+    template <
+        layer_mode mode
+        >
     class affine_
     {
         /*!
@@ -766,11 +779,11 @@ namespace dlib
         );
         /*!
             ensures
-                - #get_mode() == FC_MODE
+                - #get_mode() == mode
         !*/

         affine_(
-            const bn_& layer
+            const bn_<mode>& layer
         );
         /*!
             ensures
@@ -781,14 +794,6 @@ namespace dlib
                 - #get_mode() == layer.get_mode()
         !*/

-        explicit affine_(
-            layer_mode mode
-        );
-        /*!
-            ensures
-                - #get_mode() == mode
-        !*/
-
         layer_mode get_mode(
         ) const;
         /*!
@@ -806,22 +811,33 @@ namespace dlib
             Also note that get_layer_params() always returns an empty tensor since there
             are no learnable parameters in this object.
         !*/
-    };

-    void serialize(const affine_& item, std::ostream& out);
-    void deserialize(affine_& item, std::istream& in);
+        friend void serialize(const affine_& item, std::ostream& out);
+        friend void deserialize(affine_& item, std::istream& in);
         /*!
             provides serialization support
         !*/
+    };

     template <typename SUBNET>
-    using affine = add_layer<affine_, SUBNET>;
+    using affine_con = add_layer<affine_<CONV_MODE>, SUBNET>;
+
+    template <typename SUBNET>
+    using affine_fc = add_layer<affine_<FC_MODE>, SUBNET>;
     // ----------------------------------------------------------------------------------------

+    template <
+        long _nr,
+        long _nc,
+        int _stride_y,
+        int _stride_x
+        >
     class max_pool_
     {
         /*!
+            REQUIREMENTS ON TEMPLATE ARGUMENTS
+                All of them must be > 0.
+
             WHAT THIS OBJECT REPRESENTS
                 This is an implementation of the EXAMPLE_LAYER_ interface defined above.
                 In particular, it defines a max pooling layer that takes an input tensor
@@ -849,24 +865,10 @@ namespace dlib
         );
         /*!
             ensures
-                - #nr() == 3
-                - #nc() == 3
-                - #stride_y() == 1
-                - #stride_x() == 1
-        !*/
-
-        max_pool_(
-            long nr_,
-            long nc_,
-            int stride_y_ = 1,
-            int stride_x_ = 1
-        );
-        /*!
-            ensures
-                - #nr() == nr_
-                - #nc() == nc_
-                - #stride_y() == stride_y_
-                - #stride_x() == stride_x_
+                - #nr() == _nr
+                - #nc() == _nc
+                - #stride_y() == _stride_y
+                - #stride_x() == _stride_x
         !*/

         long nr(
@@ -911,22 +913,37 @@ namespace dlib
             Note that this layer doesn't have any parameters, so the tensor returned by
             get_layer_params() is always empty.
         !*/
-    };

-    void serialize(const max_pool_& item, std::ostream& out);
-    void deserialize(max_pool_& item, std::istream& in);
+        friend void serialize(const max_pool_& item, std::ostream& out);
+        friend void deserialize(max_pool_& item, std::istream& in);
         /*!
             provides serialization support
         !*/
+    };

-    template <typename SUBNET>
-    using max_pool = add_layer<max_pool_, SUBNET>;
+    template <
+        long nr,
+        long nc,
+        int stride_y,
+        int stride_x,
+        typename SUBNET
+        >
+    using max_pool = add_layer<max_pool_<nr,nc,stride_y,stride_x>, SUBNET>;
     // ----------------------------------------------------------------------------------------

+    template <
+        long _nr,
+        long _nc,
+        int _stride_y,
+        int _stride_x
+        >
     class avg_pool_
     {
         /*!
+            REQUIREMENTS ON TEMPLATE ARGUMENTS
+                All of them must be > 0.
+
             WHAT THIS OBJECT REPRESENTS
                 This is an implementation of the EXAMPLE_LAYER_ interface defined above.
                 In particular, it defines an average pooling layer that takes an input tensor
@@ -954,24 +971,10 @@ namespace dlib
         );
         /*!
             ensures
-                - #nr() == 3
-                - #nc() == 3
-                - #stride_y() == 1
-                - #stride_x() == 1
-        !*/
-
-        avg_pool_(
-            long nr_,
-            long nc_,
-            int stride_y_ = 1,
-            int stride_x_ = 1
-        );
-        /*!
-            ensures
-                - #nr() == nr_
-                - #nc() == nc_
-                - #stride_y() == stride_y_
-                - #stride_x() == stride_x_
+                - #nr() == _nr
+                - #nc() == _nc
+                - #stride_y() == _stride_y
+                - #stride_x() == _stride_x
         !*/

         long nr(
@@ -1016,16 +1019,22 @@ namespace dlib
             Note that this layer doesn't have any parameters, so the tensor returned by
             get_layer_params() is always empty.
         !*/
-    };

-    void serialize(const avg_pool_& item, std::ostream& out);
-    void deserialize(avg_pool_& item, std::istream& in);
+        friend void serialize(const avg_pool_& item, std::ostream& out);
+        friend void deserialize(avg_pool_& item, std::istream& in);
         /*!
             provides serialization support
         !*/
+    };

-    template <typename SUBNET>
-    using avg_pool = add_layer<avg_pool_, SUBNET>;
+    template <
+        long nr,
+        long nc,
+        int stride_y,
+        int stride_x,
+        typename SUBNET
+        >
+    using avg_pool = add_layer<avg_pool_<nr,nc,stride_y,stride_x>, SUBNET>;
     // ----------------------------------------------------------------------------------------

@@ -1094,6 +1103,14 @@ namespace dlib
         /*!
             ensures
                 - The p parameter will be initialized with initial_param_value.
+                - #get_initial_param_value() == initial_param_value.
+        !*/
+
+        float get_initial_param_value (
+        ) const;
+        /*!
+            ensures
+                - returns the initial value of the prelu parameter.
         !*/

         template <typename SUBNET> void setup (const SUBNET& sub);
...
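The prelu_ accessor documented above pairs with prelu_'s float constructor (the one used in the advanced example further down). A quick sketch:

    // prelu_ stores the initial value of its learnable parameter and the
    // new accessor reads it back:
    prelu_ p(0.2);
    cout << p.get_initial_param_value() << endl;  // prints 0.2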
@@ -1076,67 +1076,67 @@ namespace
        }
        {
            print_spinner();
-            max_pool_ l;
+            max_pool_<3,3,1,1> l;
            DLIB_TEST_MSG(test_layer(l), test_layer(l));
        }
        {
            print_spinner();
-            avg_pool_ l;
+            avg_pool_<3,3,1,1> l;
            DLIB_TEST_MSG(test_layer(l), test_layer(l));
        }
        {
            print_spinner();
-            affine_ l(CONV_MODE);
+            affine_<CONV_MODE> l;
            DLIB_TEST_MSG(test_layer(l), test_layer(l));
        }
        {
            print_spinner();
-            affine_ l(FC_MODE);
+            affine_<FC_MODE> l;
            DLIB_TEST_MSG(test_layer(l), test_layer(l));
        }
        {
            print_spinner();
-            bn_ l(CONV_MODE);
+            bn_<CONV_MODE> l;
            DLIB_TEST_MSG(test_layer(l), test_layer(l));
        }
        {
            print_spinner();
-            bn_ l(FC_MODE);
+            bn_<FC_MODE> l;
            DLIB_TEST_MSG(test_layer(l), test_layer(l));
        }
        {
            print_spinner();
-            con_ l(3,3,3,2,2);
+            con_<3,3,3,2,2> l;
            DLIB_TEST_MSG(test_layer(l), test_layer(l));
        }
        {
            print_spinner();
-            con_ l(3,3,3,1,1);
+            con_<3,3,3,1,1> l;
            DLIB_TEST_MSG(test_layer(l), test_layer(l));
        }
        {
            print_spinner();
-            con_ l(3,3,2,1,1);
+            con_<3,3,2,1,1> l;
            DLIB_TEST_MSG(test_layer(l), test_layer(l));
        }
        {
            print_spinner();
-            con_ l(2,1,1,1,1);
+            con_<2,1,1,1,1> l;
            DLIB_TEST_MSG(test_layer(l), test_layer(l));
        }
        {
            print_spinner();
-            fc_ l;
+            fc_<1,FC_HAS_BIAS> l;
            DLIB_TEST_MSG(test_layer(l), test_layer(l));
        }
        {
            print_spinner();
-            fc_ l(5,FC_HAS_BIAS);
+            fc_<5,FC_HAS_BIAS> l;
            DLIB_TEST_MSG(test_layer(l), test_layer(l));
        }
        {
            print_spinner();
-            fc_ l(5,FC_NO_BIAS);
+            fc_<5,FC_NO_BIAS> l;
            DLIB_TEST_MSG(test_layer(l), test_layer(l));
        }
        {
@@ -1168,29 +1168,16 @@ namespace

    // ----------------------------------------------------------------------------------------

-    template <typename T> using rcon = max_pool<relu<bn<con<T>>>>;
-    std::tuple<max_pool_,relu_,bn_,con_> rcon_ (unsigned long n)
-    {
-        return std::make_tuple(max_pool_(2,2,2,2),relu_(),bn_(CONV_MODE),con_(n,5,5));
-    }
-
-    template <typename T> using rfc = relu<bn<fc<T>>>;
-    std::tuple<relu_,bn_,fc_> rfc_ (unsigned long n)
-    {
-        return std::make_tuple(relu_(),bn_(),fc_(n));
-    }
+    template <unsigned long n, typename SUBNET> using rcon = max_pool<2,2,2,2,relu<bn_con<con<n,5,5,1,1,SUBNET>>>>;
+    template <unsigned long n, typename SUBNET> using rfc = relu<bn_fc<fc<n,FC_HAS_BIAS,SUBNET>>>;

    void test_tagging(
    )
    {
-        typedef loss_multiclass_log<rfc<skip1<rfc<rfc<tag1<rcon<rcon<input<matrix<unsigned char>>>>>>>>>> net_type;
+        typedef loss_multiclass_log<rfc<10,skip1<rfc<84,rfc<120,tag1<rcon<16,rcon<6,input<matrix<unsigned char>>>>>>>>>> net_type;

-        net_type net(rfc_(10),
-                     rfc_(84),
-                     rfc_(120),
-                     rcon_(16),
-                     rcon_(6)
-        );
+        net_type net;
+        net_type net2(num_fc_outputs(4));

        DLIB_TEST(layer<tag1>(net).num_layers == 8);
        DLIB_TEST(layer<skip1>(net).num_layers == 8+3+3);
...
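The net2 line in the tagging test exercises num_fc_outputs, the one layer parameter that can still be overridden at runtime. A short sketch of the behavior, per the fc_ constructors added in this commit:

    fc_<10,FC_HAS_BIAS> a;                      // a.get_num_outputs() == 10 (template default)
    fc_<10,FC_HAS_BIAS> b(num_fc_outputs(4));   // b.get_num_outputs() == 4  (runtime override)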
@@ -15,6 +15,7 @@
 using namespace std;
 using namespace dlib;

 int main(int argc, char** argv) try
 {
     if (argc != 2)
@@ -23,6 +24,8 @@ int main(int argc, char** argv) try
         return 1;
     }

     std::vector<matrix<unsigned char>> training_images;
     std::vector<unsigned long> training_labels;
     std::vector<matrix<unsigned char>> testing_images;
load_mnist_dataset(argv[1], training_images, training_labels, testing_images, testing_labels); load_mnist_dataset(argv[1], training_images, training_labels, testing_images, testing_labels);
typedef loss_multiclass_log<fc<relu<fc<relu<fc<max_pool<relu<con<max_pool<relu<con< using net_type = loss_multiclass_log<
input<matrix<unsigned char>>>>>>>>>>>>>> net_type; fc<10,FC_HAS_BIAS,
relu<fc<84,FC_HAS_BIAS,
relu<fc<120,FC_HAS_BIAS,
max_pool<2,2,2,2,relu<con<16,5,5,1,1,
max_pool<2,2,2,2,relu<con<6,5,5,1,1,
input<matrix<unsigned char>>>>>>>>>>>>>>;
net_type net(fc_(10), net_type net;
relu_(),
fc_(84),
relu_(),
fc_(120),
max_pool_(2,2,2,2),
relu_(),
con_(16,5,5),
max_pool_(2,2,2,2),
relu_(),
con_(6,5,5));
dnn_trainer<net_type> trainer(net,sgd(0.1)); dnn_trainer<net_type> trainer(net,sgd(0.01));
trainer.set_mini_batch_size(128); trainer.set_mini_batch_size(128);
trainer.be_verbose(); trainer.be_verbose();
trainer.set_synchronization_file("mnist_sync", std::chrono::seconds(20)); trainer.set_synchronization_file("mnist_sync", std::chrono::seconds(20));
...
@@ -9,23 +9,19 @@ using namespace dlib;

// ----------------------------------------------------------------------------------------

-template <typename T> using res = relu<add_prev1<bn<con<relu<bn<con<tag1<T>>>>>>>>;
-
-std::tuple<relu_,add_prev1_,bn_,con_,relu_,bn_,con_> res_ (
-    unsigned long outputs,
-    unsigned long stride = 1
-)
-{
-    return std::make_tuple(relu_(),
-                           add_prev1_(),
-                           bn_(CONV_MODE),
-                           con_(outputs,3,3,stride,stride),
-                           relu_(),
-                           bn_(CONV_MODE),
-                           con_(outputs,3,3,stride,stride));
-}
-
-template <typename T> using ares = relu<add_prev1<affine<con<relu<affine<con<tag1<T>>>>>>>>;
+template <int stride, typename SUBNET>
+using base_res = relu<add_prev1< bn_con<con<8,3,3,1,1,relu< bn_con<con<8,3,3,stride,stride,tag1<SUBNET>>>>>>>>;
+
+template <int stride, typename SUBNET>
+using base_ares = relu<add_prev1<affine_con<con<8,3,3,1,1,relu<affine_con<con<8,3,3,stride,stride,tag1<SUBNET>>>>>>>>;
+
+template <typename SUBNET> using res = base_res<1,SUBNET>;
+template <typename SUBNET> using res_down = base_res<2,SUBNET>;
+template <typename SUBNET> using ares = base_ares<1,SUBNET>;
+template <typename SUBNET> using ares_down = base_ares<2,SUBNET>;
+
+template <typename SUBNET>
+using pres = prelu<add_prev1< bn_con<con<8,3,3,1,1,prelu< bn_con<con<8,3,3,1,1,tag1<SUBNET>>>>>>>>;
// ----------------------------------------------------------------------------------------

@@ -44,24 +40,78 @@ int main(int argc, char** argv) try
     load_mnist_dataset(argv[1], training_images, training_labels, testing_images, testing_labels);
+    set_dnn_prefer_smallest_algorithms();
+
+    const unsigned long number_of_classes = 10;
-    typedef loss_multiclass_log<fc<avg_pool<
-                                res<res<res<res<
-                                repeat<10,res,
-                                res<
-                                res<
+    typedef loss_multiclass_log<fc<number_of_classes,FC_HAS_BIAS,
+                                avg_pool<11,11,11,11,
+                                res<res<res<res_down<
+                                repeat<9,res, // repeat this layer 9 times
+                                res_down<
+                                res<
                                 input<matrix<unsigned char>
                                 >>>>>>>>>>> net_type;

-    const unsigned long number_of_classes = 10;
-    net_type net(fc_(number_of_classes),
-                 avg_pool_(10,10,10,10),
-                 res_(8),res_(8),res_(8),res_(8,2),
-                 res_(8), // repeated 10 times
-                 res_(8,2),
-                 res_(8)
-    );
+    net_type net;
+
+    // If you wanted to use the same network but override the number of outputs at runtime
+    // you can do so like this:
+    net_type net2(num_fc_outputs(15));
+
+    // Let's imagine we wanted to replace some of the relu layers with prelu layers.  We
+    // might do it like this:
+    typedef loss_multiclass_log<fc<number_of_classes,FC_HAS_BIAS,
+                                avg_pool<11,11,11,11,
+                                pres<res<res<res_down< // 2 prelu layers here
+                                tag4<repeat<9,pres,    // 9 groups, each containing 2 prelu layers
+                                res_down<
+                                res<
+                                input<matrix<unsigned char>
+                                >>>>>>>>>>>> net_type2;
+
+    // prelu layers have a floating point parameter.  If you want to set it to something
+    // other than its default value you can do so like this:
+    net_type2 pnet(prelu_(0.2),
+                   prelu_(0.2),
+                   repeat_group(prelu_(0.3),prelu_(0.4)) // Initialize all the prelu instances in the repeat
+                                                         // layer.  repeat_group() is needed to group the things
+                                                         // that are part of repeat's block.
+                   );
+    // As you can see, a network will greedily assign things given to its constructor to
+    // the layers inside itself.  The assignment is done in the order the layers are
+    // defined but it will skip layers where the assignment doesn't make sense.
+
+    // You can access sub layers of the network like this:
+    net.subnet().subnet().get_output();
+    layer<2>(net).get_output();
+    layer<relu>(net).get_output();
+    layer<tag1>(net).get_output();
+
+    // To further illustrate the use of layer(), let's loop over the repeated layers and
+    // print out their parameters.  But first, let's grab a reference to the repeat layer.
+    // Since we tagged the repeat layer we can access it using the layer() method.
+    // layer<tag4>(pnet) returns the tag4 layer, but we want the repeat layer so we can
+    // give an integer as the second argument and it will jump that many layers down the
+    // network.  In our case we need to jump just 1 layer down to get to repeat.
+    auto&& repeat_layer = layer<tag4,1>(pnet);
+    for (size_t i = 0; i < repeat_layer.num_repetitions(); ++i)
+    {
+        // The repeat layer just instantiates the network block a bunch of times as a
+        // network object.  get_repeated_layer() allows us to grab each of these instances.
+        auto&& repeated_layer = repeat_layer.get_repeated_layer(i);
+        // Now that we have the i-th layer inside our repeat layer we can look at its
+        // properties.  Recall that we repeated the "pres" network block, which is itself a
+        // network with a bunch of layers.  So we can again use layer() to jump to the
+        // prelu layers we are interested in like so:
+        prelu_ prelu1 = layer<prelu>(repeated_layer).layer_details();
+        prelu_ prelu2 = layer<prelu>(repeated_layer.subnet()).layer_details();
+        cout << "first prelu layer parameter value: "<< prelu1.get_initial_param_value() << endl;
+        cout << "second prelu layer parameter value: "<< prelu2.get_initial_param_value() << endl;
+    }
     dnn_trainer<net_type,adam> trainer(net,adam(0.001));
@@ -89,20 +139,16 @@ int main(int argc, char** argv) try
     // wait for threaded processing to stop.
     trainer.get_net();

-    // You can access sub layers of the network like this:
-    net.subnet().subnet().get_output();
-    layer<2>(net).get_output();
-    layer<avg_pool>(net).get_output();
-
     net.clean();
     serialize("mnist_res_network.dat") << net;
-    typedef loss_multiclass_log<fc<avg_pool<
-                                ares<ares<ares<ares<
-                                repeat<10,ares,
-                                ares<
-                                ares<
+    typedef loss_multiclass_log<fc<number_of_classes,FC_HAS_BIAS,
+                                avg_pool<11,11,11,11,
+                                ares<ares<ares<ares_down<
+                                repeat<9,res,
+                                ares_down<
+                                ares<
                                 input<matrix<unsigned char>
                                 >>>>>>>>>>> test_net_type;
...