Commit 81847660 authored by Davis King's avatar Davis King

Added conv_ spec and did a little cleanup.

parent adc022c7
...@@ -21,6 +21,16 @@ namespace dlib ...@@ -21,6 +21,16 @@ namespace dlib
class con_ class con_
{ {
public: public:
con_ (
) :
_num_filters(1),
_nr(3),
_nc(3),
_stride_y(1),
_stride_x(1)
{}
con_( con_(
long num_filters_, long num_filters_,
long nr_, long nr_,
...@@ -28,22 +38,28 @@ namespace dlib ...@@ -28,22 +38,28 @@ namespace dlib
int stride_y_ = 1, int stride_y_ = 1,
int stride_x_ = 1 int stride_x_ = 1
) : ) :
num_filters(num_filters_), _num_filters(num_filters_),
nr(nr_), _nr(nr_),
nc(nc_), _nc(nc_),
stride_y(stride_y_), _stride_y(stride_y_),
stride_x(stride_x_) _stride_x(stride_x_)
{} {}
long num_filters() const { return _num_filters; }
long nr() const { return _nr; }
long nc() const { return _nc; }
long stride_y() const { return _stride_y; }
long stride_x() const { return _stride_x; }
con_ ( con_ (
const con_& item const con_& item
) : ) :
params(item.params), params(item.params),
num_filters(item.num_filters), _num_filters(item._num_filters),
nr(item.nr), _nr(item._nr),
nc(item.nc), _nc(item._nc),
stride_y(item.stride_y), _stride_y(item._stride_y),
stride_x(item.stride_x), _stride_x(item._stride_x),
filters(item.filters), filters(item.filters),
biases(item.biases) biases(item.biases)
{ {
...@@ -61,11 +77,11 @@ namespace dlib ...@@ -61,11 +77,11 @@ namespace dlib
// this->conv is non-copyable and basically stateless, so we have to write our // this->conv is non-copyable and basically stateless, so we have to write our
// own copy to avoid trying to copy it and getting an error. // own copy to avoid trying to copy it and getting an error.
params = item.params; params = item.params;
num_filters = item.num_filters; _num_filters = item._num_filters;
nr = item.nr; _nr = item._nr;
nc = item.nc; _nc = item._nc;
stride_y = item.stride_y; _stride_y = item._stride_y;
stride_x = item.stride_x; _stride_x = item._stride_x;
filters = item.filters; filters = item.filters;
biases = item.biases; biases = item.biases;
return *this; return *this;
...@@ -74,16 +90,16 @@ namespace dlib ...@@ -74,16 +90,16 @@ namespace dlib
template <typename SUBNET> template <typename SUBNET>
void setup (const SUBNET& sub) void setup (const SUBNET& sub)
{ {
long num_inputs = nr*nc*sub.get_output().k(); long num_inputs = _nr*_nc*sub.get_output().k();
long num_outputs = num_filters; long num_outputs = _num_filters;
// allocate params for the filters and also for the filter bias values. // allocate params for the filters and also for the filter bias values.
params.set_size(num_inputs*num_filters + num_filters); params.set_size(num_inputs*_num_filters + _num_filters);
dlib::rand rnd("con_"+cast_to_string(num_outputs+num_inputs)); dlib::rand rnd("con_"+cast_to_string(num_outputs+num_inputs));
randomize_parameters(params, num_inputs+num_outputs, rnd); randomize_parameters(params, num_inputs+num_outputs, rnd);
filters = alias_tensor(num_filters, sub.get_output().k(), nr, nc); filters = alias_tensor(_num_filters, sub.get_output().k(), _nr, _nc);
biases = alias_tensor(1,num_filters); biases = alias_tensor(1,_num_filters);
// set the initial bias values to zero // set the initial bias values to zero
biases(params,filters.size()) = 0; biases(params,filters.size()) = 0;
...@@ -95,8 +111,8 @@ namespace dlib ...@@ -95,8 +111,8 @@ namespace dlib
conv(output, conv(output,
sub.get_output(), sub.get_output(),
filters(params,0), filters(params,0),
stride_y, _stride_y,
stride_x); _stride_x);
tt::add(1,output,1,biases(params,filters.size())); tt::add(1,output,1,biases(params,filters.size()));
} }
...@@ -118,11 +134,11 @@ namespace dlib ...@@ -118,11 +134,11 @@ namespace dlib
{ {
serialize("con_", out); serialize("con_", out);
serialize(item.params, out); serialize(item.params, out);
serialize(item.num_filters, out); serialize(item._num_filters, out);
serialize(item.nr, out); serialize(item._nr, out);
serialize(item.nc, out); serialize(item._nc, out);
serialize(item.stride_y, out); serialize(item._stride_y, out);
serialize(item.stride_y, out); serialize(item._stride_y, out);
serialize(item.filters, out); serialize(item.filters, out);
serialize(item.biases, out); serialize(item.biases, out);
} }
...@@ -134,11 +150,11 @@ namespace dlib ...@@ -134,11 +150,11 @@ namespace dlib
if (version != "con_") if (version != "con_")
throw serialization_error("Unexpected version found while deserializing dlib::con_."); throw serialization_error("Unexpected version found while deserializing dlib::con_.");
deserialize(item.params, in); deserialize(item.params, in);
deserialize(item.num_filters, in); deserialize(item._num_filters, in);
deserialize(item.nr, in); deserialize(item._nr, in);
deserialize(item.nc, in); deserialize(item._nc, in);
deserialize(item.stride_y, in); deserialize(item._stride_y, in);
deserialize(item.stride_y, in); deserialize(item._stride_y, in);
deserialize(item.filters, in); deserialize(item.filters, in);
deserialize(item.biases, in); deserialize(item.biases, in);
} }
...@@ -146,11 +162,11 @@ namespace dlib ...@@ -146,11 +162,11 @@ namespace dlib
private: private:
resizable_tensor params; resizable_tensor params;
long num_filters; long _num_filters;
long nr; long _nr;
long nc; long _nc;
int stride_y; int _stride_y;
int stride_x; int _stride_x;
alias_tensor filters, biases; alias_tensor filters, biases;
tt::tensor_conv conv; tt::tensor_conv conv;
......
...@@ -329,6 +329,8 @@ namespace dlib ...@@ -329,6 +329,8 @@ namespace dlib
unsigned long num_outputs unsigned long num_outputs
); );
/*! /*!
requires
- num_outputs > 0
ensures ensures
- #get_num_outputs() == num_outputs - #get_num_outputs() == num_outputs
!*/ !*/
...@@ -363,6 +365,112 @@ namespace dlib ...@@ -363,6 +365,112 @@ namespace dlib
template <typename SUBNET> template <typename SUBNET>
using fc = add_layer<fc_, SUBNET>; using fc = add_layer<fc_, SUBNET>;
// ----------------------------------------------------------------------------------------
class con_
{
/*!
WHAT THIS OBJECT REPRESENTS
This is an implementation of the EXAMPLE_LAYER_ interface defined above.
In particular, it defines a convolution layer that takes an input tensor
(nominally representing an image) and convolves it with a set of filters
and then outputs the results.
!*/
public:
con_(
);
/*!
ensures
- #num_filters() == 1
- #nr() == 3
- #nc() == 3
- #stride_y() == 1
- #stride_x() == 1
!*/
con_(
long num_filters_,
long nr_,
long nc_,
int stride_y_ = 1,
int stride_x_ = 1
);
/*!
requires
- num_filters_ > 0
- nr_ > 0
- nc_ > 0
- stride_y_ > 0
- stride_x_ > 0
ensures
- #num_filters() == num_filters_
- #nr() == nr_
- #nc() == nc_
- #stride_y() == stride_y_
- #stride_x() == stride_x_
!*/
long num_filters(
) const;
/*!
ensures
- returns the number of filters contained in this layer. The k dimension
of the output tensors produced by this layer will be equal to the number
of filters.
!*/
long nr(
) const;
/*!
ensures
- returns the number of rows in the filters in this layer.
!*/
long nc(
) const;
/*!
ensures
- returns the number of columns in the filters in this layer.
!*/
long stride_y(
) const;
/*!
ensures
- returns the vertical stride used when convolving the filters over an
image. That is, each filter will be moved stride_y() pixels down at a
time when it moves over the image.
!*/
long stride_x(
) const;
/*!
ensures
- returns the horizontal stride used when convolving the filters over an
image. That is, each filter will be moved stride_x() pixels right at a
time when it moves over the image.
!*/
template <typename SUBNET> void setup (const SUBNET& sub);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
const tensor& get_layer_params() const;
tensor& get_layer_params();
/*!
These functions are implemented as described in the EXAMPLE_LAYER_ interface.
!*/
};
void serialize(const con_& item, std::ostream& out);
void deserialize(con_& item, std::istream& in);
/*!
provides serialization support
!*/
template <typename SUBNET>
using con = add_layer<con_, SUBNET>;
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
class relu_ class relu_
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment