Commit 4ef5908b authored by Davis King

Pushed the padding parameters into the con_, max_pool_, and avg_pool_
interfaces.  Also changed the default behavior when the stride isn't 1.  Now
the filters will be applied only to the "valid" part of the image.
parent 6bab1f50
...@@ -24,7 +24,9 @@ namespace dlib ...@@ -24,7 +24,9 @@ namespace dlib
long _nr, long _nr,
long _nc, long _nc,
int _stride_y, int _stride_y,
int _stride_x int _stride_x,
int _padding_y = _stride_y!=1? 0 : _nr/2,
int _padding_x = _stride_x!=1? 0 : _nc/2
> >
class con_ class con_
{ {
...@@ -35,9 +37,13 @@ namespace dlib ...@@ -35,9 +37,13 @@ namespace dlib
static_assert(_nc > 0, "The number of columns in a filter must be > 0"); static_assert(_nc > 0, "The number of columns in a filter must be > 0");
static_assert(_stride_y > 0, "The filter stride must be > 0"); static_assert(_stride_y > 0, "The filter stride must be > 0");
static_assert(_stride_x > 0, "The filter stride must be > 0"); static_assert(_stride_x > 0, "The filter stride must be > 0");
static_assert(0 <= _padding_y && _padding_y < _nr, "The padding must be smaller than the filter size.");
static_assert(0 <= _padding_x && _padding_x < _nc, "The padding must be smaller than the filter size.");
con_( con_(
) ) :
padding_y_(_padding_y),
padding_x_(_padding_x)
{} {}
long num_filters() const { return _num_filters; } long num_filters() const { return _num_filters; }
...@@ -45,13 +51,17 @@ namespace dlib ...@@ -45,13 +51,17 @@ namespace dlib
long nc() const { return _nc; } long nc() const { return _nc; }
long stride_y() const { return _stride_y; } long stride_y() const { return _stride_y; }
long stride_x() const { return _stride_x; } long stride_x() const { return _stride_x; }
long padding_y() const { return padding_y_; }
long padding_x() const { return padding_x_; }
con_ ( con_ (
const con_& item const con_& item
) : ) :
params(item.params), params(item.params),
filters(item.filters), filters(item.filters),
biases(item.biases) biases(item.biases),
padding_y_(item.padding_y_),
padding_x_(item.padding_x_)
{ {
// this->conv is non-copyable and basically stateless, so we have to write our // this->conv is non-copyable and basically stateless, so we have to write our
// own copy to avoid trying to copy it and getting an error. // own copy to avoid trying to copy it and getting an error.
...@@ -69,6 +79,8 @@ namespace dlib ...@@ -69,6 +79,8 @@ namespace dlib
params = item.params; params = item.params;
filters = item.filters; filters = item.filters;
biases = item.biases; biases = item.biases;
padding_y_ = item.padding_y_;
padding_x_ = item.padding_x_;
return *this; return *this;
} }
...@@ -98,8 +110,8 @@ namespace dlib ...@@ -98,8 +110,8 @@ namespace dlib
filters(params,0), filters(params,0),
_stride_y, _stride_y,
_stride_x, _stride_x,
_nr/2, padding_y_,
_nc/2 padding_x_
); );
tt::add(1,output,1,biases(params,filters.size())); tt::add(1,output,1,biases(params,filters.size()));
...@@ -120,13 +132,15 @@ namespace dlib ...@@ -120,13 +132,15 @@ namespace dlib
friend void serialize(const con_& item, std::ostream& out) friend void serialize(const con_& item, std::ostream& out)
{ {
serialize("con_", out); serialize("con_2", out);
serialize(item.params, out); serialize(item.params, out);
serialize(_num_filters, out); serialize(_num_filters, out);
serialize(_nr, out); serialize(_nr, out);
serialize(_nc, out); serialize(_nc, out);
serialize(_stride_y, out); serialize(_stride_y, out);
serialize(_stride_x, out); serialize(_stride_x, out);
serialize(item.padding_y_, out);
serialize(item.padding_x_, out);
serialize(item.filters, out); serialize(item.filters, out);
serialize(item.biases, out); serialize(item.biases, out);
} }
...@@ -135,23 +149,44 @@ namespace dlib ...@@ -135,23 +149,44 @@ namespace dlib
{ {
std::string version; std::string version;
deserialize(version, in); deserialize(version, in);
if (version != "con_")
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_.");
deserialize(item.params, in);
long num_filters; long num_filters;
long nr; long nr;
long nc; long nc;
int stride_y; int stride_y;
int stride_x; int stride_x;
deserialize(num_filters, in); if (version == "con_")
deserialize(nr, in); {
deserialize(nc, in); deserialize(item.params, in);
deserialize(stride_y, in); deserialize(num_filters, in);
deserialize(stride_x, in); deserialize(nr, in);
deserialize(item.filters, in); deserialize(nc, in);
deserialize(item.biases, in); deserialize(stride_y, in);
deserialize(stride_x, in);
deserialize(item.filters, in);
deserialize(item.biases, in);
item.padding_y_ = nr/2;
item.padding_x_ = nc/2;
}
else if (version == "con_2")
{
deserialize(item.params, in);
deserialize(num_filters, in);
deserialize(nr, in);
deserialize(nc, in);
deserialize(stride_y, in);
deserialize(stride_x, in);
deserialize(item.padding_y_, in);
deserialize(item.padding_x_, in);
deserialize(item.filters, in);
deserialize(item.biases, in);
if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::con_");
if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::con_");
}
else
{
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_.");
}
if (num_filters != _num_filters) throw serialization_error("Wrong num_filters found while deserializing dlib::con_"); if (num_filters != _num_filters) throw serialization_error("Wrong num_filters found while deserializing dlib::con_");
if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_"); if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_");
...@@ -169,6 +204,8 @@ namespace dlib ...@@ -169,6 +204,8 @@ namespace dlib
<< ", nc="<<_nc << ", nc="<<_nc
<< ", stride_y="<<_stride_y << ", stride_y="<<_stride_y
<< ", stride_x="<<_stride_x << ", stride_x="<<_stride_x
<< ", padding_y="<<item.padding_y_
<< ", padding_x="<<item.padding_x_
<< ")"; << ")";
return out; return out;
} }
...@@ -181,6 +218,11 @@ namespace dlib ...@@ -181,6 +218,11 @@ namespace dlib
tt::tensor_conv conv; tt::tensor_conv conv;
// These are here only because older versions of con (which you might encounter
// serialized to disk) used different padding settings.
int padding_y_;
int padding_x_;
}; };
template < template <
...@@ -199,7 +241,9 @@ namespace dlib ...@@ -199,7 +241,9 @@ namespace dlib
long _nr, long _nr,
long _nc, long _nc,
int _stride_y, int _stride_y,
int _stride_x int _stride_x,
int _padding_y = _stride_y!=1? 0 : _nr/2,
int _padding_x = _stride_x!=1? 0 : _nc/2
> >
class max_pool_ class max_pool_
{ {
...@@ -207,24 +251,33 @@ namespace dlib ...@@ -207,24 +251,33 @@ namespace dlib
static_assert(_nc > 0, "The number of columns in a filter must be > 0"); static_assert(_nc > 0, "The number of columns in a filter must be > 0");
static_assert(_stride_y > 0, "The filter stride must be > 0"); static_assert(_stride_y > 0, "The filter stride must be > 0");
static_assert(_stride_x > 0, "The filter stride must be > 0"); static_assert(_stride_x > 0, "The filter stride must be > 0");
static_assert(0 <= _padding_y && _padding_y < _nr, "The padding must be smaller than the filter size.");
static_assert(0 <= _padding_x && _padding_x < _nc, "The padding must be smaller than the filter size.");
public: public:
max_pool_( max_pool_(
) {} ) :
padding_y_(_padding_y),
padding_x_(_padding_x)
{}
long nr() const { return _nr; } long nr() const { return _nr; }
long nc() const { return _nc; } long nc() const { return _nc; }
long stride_y() const { return _stride_y; } long stride_y() const { return _stride_y; }
long stride_x() const { return _stride_x; } long stride_x() const { return _stride_x; }
long padding_y() const { return padding_y_; }
long padding_x() const { return padding_x_; }
max_pool_ ( max_pool_ (
const max_pool_& const max_pool_& item
) ) :
padding_y_(item.padding_y_),
padding_x_(item.padding_x_)
{ {
// this->mp is non-copyable so we have to write our own copy to avoid trying to // this->mp is non-copyable so we have to write our own copy to avoid trying to
// copy it and getting an error. // copy it and getting an error.
mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x, _nr/2, _nc/2); mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x, padding_y_, padding_x_);
} }
max_pool_& operator= ( max_pool_& operator= (
...@@ -234,16 +287,19 @@ namespace dlib ...@@ -234,16 +287,19 @@ namespace dlib
if (this == &item) if (this == &item)
return *this; return *this;
padding_y_ = item.padding_y_;
padding_x_ = item.padding_x_;
// this->mp is non-copyable so we have to write our own copy to avoid trying to // this->mp is non-copyable so we have to write our own copy to avoid trying to
// copy it and getting an error. // copy it and getting an error.
mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x, _nr/2, _nc/2); mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x, padding_y_, padding_x_);
return *this; return *this;
} }
template <typename SUBNET> template <typename SUBNET>
void setup (const SUBNET& /*sub*/) void setup (const SUBNET& /*sub*/)
{ {
mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x, _nr/2, _nc/2); mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x, padding_y_, padding_x_);
} }
template <typename SUBNET> template <typename SUBNET>
...@@ -263,35 +319,55 @@ namespace dlib ...@@ -263,35 +319,55 @@ namespace dlib
friend void serialize(const max_pool_& item, std::ostream& out) friend void serialize(const max_pool_& item, std::ostream& out)
{ {
serialize("max_pool_", out); serialize("max_pool_2", out);
serialize(_nr, out); serialize(_nr, out);
serialize(_nc, out); serialize(_nc, out);
serialize(_stride_y, out); serialize(_stride_y, out);
serialize(_stride_x, out); serialize(_stride_x, out);
serialize(item.padding_y_, out);
serialize(item.padding_x_, out);
} }
friend void deserialize(max_pool_& item, std::istream& in) friend void deserialize(max_pool_& item, std::istream& in)
{ {
std::string version; std::string version;
deserialize(version, in); deserialize(version, in);
if (version != "max_pool_")
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::max_pool_.");
item.mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x, _nr/2, _nc/2);
long nr; long nr;
long nc; long nc;
int stride_y; int stride_y;
int stride_x; int stride_x;
if (version == "max_pool_")
{
deserialize(nr, in);
deserialize(nc, in);
deserialize(stride_y, in);
deserialize(stride_x, in);
item.padding_y_ = nr/2;
item.padding_x_ = nc/2;
}
else if (version == "max_pool_2")
{
deserialize(nr, in);
deserialize(nc, in);
deserialize(stride_y, in);
deserialize(stride_x, in);
deserialize(item.padding_y_, in);
deserialize(item.padding_x_, in);
if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::max_pool_");
if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::max_pool_");
}
else
{
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::max_pool_.");
}
deserialize(nr, in);
deserialize(nc, in);
deserialize(stride_y, in);
deserialize(stride_x, in);
if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::max_pool_"); if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::max_pool_");
if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::max_pool_"); if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::max_pool_");
if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::max_pool_"); if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::max_pool_");
if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::max_pool_"); if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::max_pool_");
item.mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x, item.padding_y_, item.padding_x_);
} }
friend std::ostream& operator<<(std::ostream& out, const max_pool_& item) friend std::ostream& operator<<(std::ostream& out, const max_pool_& item)
...@@ -301,6 +377,8 @@ namespace dlib ...@@ -301,6 +377,8 @@ namespace dlib
<< ", nc="<<_nc << ", nc="<<_nc
<< ", stride_y="<<_stride_y << ", stride_y="<<_stride_y
<< ", stride_x="<<_stride_x << ", stride_x="<<_stride_x
<< ", padding_y="<<item.padding_y_
<< ", padding_x="<<item.padding_x_
<< ")"; << ")";
return out; return out;
} }
...@@ -311,6 +389,9 @@ namespace dlib ...@@ -311,6 +389,9 @@ namespace dlib
tt::pooling mp; tt::pooling mp;
resizable_tensor params; resizable_tensor params;
int padding_y_;
int padding_x_;
}; };
template < template <
...@@ -328,7 +409,9 @@ namespace dlib ...@@ -328,7 +409,9 @@ namespace dlib
long _nr, long _nr,
long _nc, long _nc,
int _stride_y, int _stride_y,
int _stride_x int _stride_x,
int _padding_y = _stride_y!=1? 0 : _nr/2,
int _padding_x = _stride_x!=1? 0 : _nc/2
> >
class avg_pool_ class avg_pool_
{ {
...@@ -337,22 +420,31 @@ namespace dlib ...@@ -337,22 +420,31 @@ namespace dlib
static_assert(_nc > 0, "The number of columns in a filter must be > 0"); static_assert(_nc > 0, "The number of columns in a filter must be > 0");
static_assert(_stride_y > 0, "The filter stride must be > 0"); static_assert(_stride_y > 0, "The filter stride must be > 0");
static_assert(_stride_x > 0, "The filter stride must be > 0"); static_assert(_stride_x > 0, "The filter stride must be > 0");
static_assert(0 <= _padding_y && _padding_y < _nr, "The padding must be smaller than the filter size.");
static_assert(0 <= _padding_x && _padding_x < _nc, "The padding must be smaller than the filter size.");
avg_pool_( avg_pool_(
) {} ) :
padding_y_(_padding_y),
padding_x_(_padding_x)
{}
long nr() const { return _nr; } long nr() const { return _nr; }
long nc() const { return _nc; } long nc() const { return _nc; }
long stride_y() const { return _stride_y; } long stride_y() const { return _stride_y; }
long stride_x() const { return _stride_x; } long stride_x() const { return _stride_x; }
long padding_y() const { return padding_y_; }
long padding_x() const { return padding_x_; }
avg_pool_ ( avg_pool_ (
const avg_pool_& const avg_pool_& item
) ) :
padding_y_(item.padding_y_),
padding_x_(item.padding_x_)
{ {
// this->ap is non-copyable so we have to write our own copy to avoid trying to // this->ap is non-copyable so we have to write our own copy to avoid trying to
// copy it and getting an error. // copy it and getting an error.
ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x, _nr/2, _nc/2); ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x, padding_y_, padding_x_);
} }
avg_pool_& operator= ( avg_pool_& operator= (
...@@ -362,16 +454,19 @@ namespace dlib ...@@ -362,16 +454,19 @@ namespace dlib
if (this == &item) if (this == &item)
return *this; return *this;
padding_y_ = item.padding_y_;
padding_x_ = item.padding_x_;
// this->ap is non-copyable so we have to write our own copy to avoid trying to // this->ap is non-copyable so we have to write our own copy to avoid trying to
// copy it and getting an error. // copy it and getting an error.
ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x, _nr/2, _nc/2); ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x, padding_y_, padding_x_);
return *this; return *this;
} }
template <typename SUBNET> template <typename SUBNET>
void setup (const SUBNET& /*sub*/) void setup (const SUBNET& /*sub*/)
{ {
ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x, _nr/2, _nc/2); ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x, padding_y_, padding_x_);
} }
template <typename SUBNET> template <typename SUBNET>
...@@ -391,35 +486,55 @@ namespace dlib ...@@ -391,35 +486,55 @@ namespace dlib
friend void serialize(const avg_pool_& item, std::ostream& out) friend void serialize(const avg_pool_& item, std::ostream& out)
{ {
serialize("avg_pool_", out); serialize("avg_pool_2", out);
serialize(_nr, out); serialize(_nr, out);
serialize(_nc, out); serialize(_nc, out);
serialize(_stride_y, out); serialize(_stride_y, out);
serialize(_stride_x, out); serialize(_stride_x, out);
serialize(item.padding_y_, out);
serialize(item.padding_x_, out);
} }
friend void deserialize(avg_pool_& item, std::istream& in) friend void deserialize(avg_pool_& item, std::istream& in)
{ {
std::string version; std::string version;
deserialize(version, in); deserialize(version, in);
if (version != "avg_pool_")
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::avg_pool_.");
item.ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x, _nr/2, _nc/2);
long nr; long nr;
long nc; long nc;
int stride_y; int stride_y;
int stride_x; int stride_x;
if (version == "avg_pool_")
{
deserialize(nr, in);
deserialize(nc, in);
deserialize(stride_y, in);
deserialize(stride_x, in);
item.padding_y_ = nr/2;
item.padding_x_ = nc/2;
}
else if (version == "avg_pool_2")
{
deserialize(nr, in);
deserialize(nc, in);
deserialize(stride_y, in);
deserialize(stride_x, in);
deserialize(item.padding_y_, in);
deserialize(item.padding_x_, in);
if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::avg_pool_");
if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::avg_pool_");
}
else
{
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::avg_pool_.");
}
deserialize(nr, in);
deserialize(nc, in);
deserialize(stride_y, in);
deserialize(stride_x, in);
if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::avg_pool_"); if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::avg_pool_");
if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::avg_pool_"); if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::avg_pool_");
if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::avg_pool_"); if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::avg_pool_");
if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::avg_pool_"); if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::avg_pool_");
item.ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x, item.padding_y_, item.padding_x_);
} }
friend std::ostream& operator<<(std::ostream& out, const avg_pool_& item) friend std::ostream& operator<<(std::ostream& out, const avg_pool_& item)
...@@ -429,6 +544,8 @@ namespace dlib ...@@ -429,6 +544,8 @@ namespace dlib
<< ", nc="<<_nc << ", nc="<<_nc
<< ", stride_y="<<_stride_y << ", stride_y="<<_stride_y
<< ", stride_x="<<_stride_x << ", stride_x="<<_stride_x
<< ", padding_y="<<item.padding_y_
<< ", padding_x="<<item.padding_x_
<< ")"; << ")";
return out; return out;
} }
...@@ -436,6 +553,9 @@ namespace dlib ...@@ -436,6 +553,9 @@ namespace dlib
tt::pooling ap; tt::pooling ap;
resizable_tensor params; resizable_tensor params;
int padding_y_;
int padding_x_;
}; };
template < template <
......
...@@ -416,13 +416,18 @@ namespace dlib ...@@ -416,13 +416,18 @@ namespace dlib
long _nr, long _nr,
long _nc, long _nc,
int _stride_y, int _stride_y,
int _stride_x int _stride_x,
int _padding_y = _stride_y!=1? 0 : _nr/2,
int _padding_x = _stride_x!=1? 0 : _nc/2
> >
class con_ class con_
{ {
/*! /*!
REQUIREMENTS ON TEMPLATE ARGUMENTS REQUIREMENTS ON TEMPLATE ARGUMENTS
All of them must be > 0. All of them must be > 0.
Also, we require that:
- 0 <= _padding_y && _padding_y < _nr
- 0 <= _padding_x && _padding_x < _nc
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
...@@ -434,8 +439,8 @@ namespace dlib ...@@ -434,8 +439,8 @@ namespace dlib
IN be the input tensor and OUT the output tensor): IN be the input tensor and OUT the output tensor):
- OUT.num_samples() == IN.num_samples() - OUT.num_samples() == IN.num_samples()
- OUT.k() == num_filters() - OUT.k() == num_filters()
- OUT.nr() == 1+(IN.nr()-nr()%2)/stride_y() - OUT.nr() == 1+(IN.nr() + 2*padding_y() - nr())/stride_y()
- OUT.nc() == 1+(IN.nc()-nc()%2)/stride_x() - OUT.nc() == 1+(IN.nc() + 2*padding_x() - nc())/stride_x()
!*/ !*/
public: public:
...@@ -448,6 +453,8 @@ namespace dlib ...@@ -448,6 +453,8 @@ namespace dlib
- #nc() == _nc - #nc() == _nc
- #stride_y() == _stride_y - #stride_y() == _stride_y
- #stride_x() == _stride_x - #stride_x() == _stride_x
- #padding_y() == _padding_y
- #padding_x() == _padding_x
!*/ !*/
long num_filters( long num_filters(
...@@ -491,6 +498,22 @@ namespace dlib ...@@ -491,6 +498,22 @@ namespace dlib
time when it moves over the image. time when it moves over the image.
!*/ !*/
long padding_y(
) const;
/*!
ensures
- returns the number of pixels of zero padding added to the top and bottom
sides of the image.
!*/
long padding_x(
) const;
/*!
ensures
- returns the number of pixels of zero padding added to the left and right
sides of the image.
!*/
template <typename SUBNET> void setup (const SUBNET& sub); template <typename SUBNET> void setup (const SUBNET& sub);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output); template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad); template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
...@@ -813,13 +836,18 @@ namespace dlib ...@@ -813,13 +836,18 @@ namespace dlib
long _nr, long _nr,
long _nc, long _nc,
int _stride_y, int _stride_y,
int _stride_x int _stride_x,
int _padding_y = _stride_y!=1? 0 : _nr/2,
int _padding_x = _stride_x!=1? 0 : _nc/2
> >
class max_pool_ class max_pool_
{ {
/*! /*!
REQUIREMENTS ON TEMPLATE ARGUMENTS REQUIREMENTS ON TEMPLATE ARGUMENTS
All of them must be > 0. All of them must be > 0.
Also, we require that:
- 0 <= _padding_y && _padding_y < _nr
- 0 <= _padding_x && _padding_x < _nc
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
...@@ -832,14 +860,14 @@ namespace dlib ...@@ -832,14 +860,14 @@ namespace dlib
then OUT is defined as follows: then OUT is defined as follows:
- OUT.num_samples() == IN.num_samples() - OUT.num_samples() == IN.num_samples()
- OUT.k() == IN.k() - OUT.k() == IN.k()
- OUT.nr() == 1+(IN.nr()-nr()%2)/stride_y() - OUT.nr() == 1+(IN.nr() + 2*padding_y() - nr())/stride_y()
- OUT.nc() == 1+(IN.nc()-nc()%2)/stride_x() - OUT.nc() == 1+(IN.nc() + 2*padding_x() - nc())/stride_x()
- for all valid s, k, r, and c: - for all valid s, k, r, and c:
- image_plane(OUT,s,k)(r,c) == max(subm_clipped(image_plane(IN,s,k), - image_plane(OUT,s,k)(r,c) == max(subm_clipped(image_plane(IN,s,k),
centered_rect(r*stride_y(), centered_rect(x*stride_x() + nc()/2 - padding_x(),
c*stride_x(), y*stride_y() + nr()/2 - padding_y(),
nr(), nc(),
nc()))) nr())))
!*/ !*/
public: public:
...@@ -852,6 +880,8 @@ namespace dlib ...@@ -852,6 +880,8 @@ namespace dlib
- #nc() == _nc - #nc() == _nc
- #stride_y() == _stride_y - #stride_y() == _stride_y
- #stride_x() == _stride_x - #stride_x() == _stride_x
- #padding_y() == _padding_y
- #padding_x() == _padding_x
!*/ !*/
long nr( long nr(
...@@ -886,6 +916,22 @@ namespace dlib ...@@ -886,6 +916,22 @@ namespace dlib
at a time when it moves over the image. at a time when it moves over the image.
!*/ !*/
long padding_y(
) const;
/*!
ensures
- returns the number of pixels of zero padding added to the top and bottom
sides of the image.
!*/
long padding_x(
) const;
/*!
ensures
- returns the number of pixels of zero padding added to the left and right
sides of the image.
!*/
template <typename SUBNET> void setup (const SUBNET& sub); template <typename SUBNET> void setup (const SUBNET& sub);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output); template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
template <typename SUBNET> void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& params_grad); template <typename SUBNET> void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
...@@ -913,13 +959,18 @@ namespace dlib ...@@ -913,13 +959,18 @@ namespace dlib
long _nr, long _nr,
long _nc, long _nc,
int _stride_y, int _stride_y,
int _stride_x int _stride_x,
int _padding_y = _stride_y!=1? 0 : _nr/2,
int _padding_x = _stride_x!=1? 0 : _nc/2
> >
class avg_pool_ class avg_pool_
{ {
/*! /*!
REQUIREMENTS ON TEMPLATE ARGUMENTS REQUIREMENTS ON TEMPLATE ARGUMENTS
All of them must be > 0. All of them must be > 0.
Also, we require that:
- 0 <= _padding_y && _padding_y < _nr
- 0 <= _padding_x && _padding_x < _nc
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
...@@ -934,12 +985,14 @@ namespace dlib ...@@ -934,12 +985,14 @@ namespace dlib
- OUT.k() == IN.k() - OUT.k() == IN.k()
- OUT.nr() == 1+(IN.nr()-nr()%2)/stride_y() - OUT.nr() == 1+(IN.nr()-nr()%2)/stride_y()
- OUT.nc() == 1+(IN.nc()-nc()%2)/stride_x() - OUT.nc() == 1+(IN.nc()-nc()%2)/stride_x()
- OUT.nr() == 1+(IN.nr() + 2*padding_y() - nr())/stride_y()
- OUT.nc() == 1+(IN.nc() + 2*padding_x() - nc())/stride_x()
- for all valid s, k, r, and c: - for all valid s, k, r, and c:
- image_plane(OUT,s,k)(r,c) == mean(subm_clipped(image_plane(IN,s,k), - image_plane(OUT,s,k)(r,c) == mean(subm_clipped(image_plane(IN,s,k),
centered_rect(r*stride_y(), centered_rect(x*stride_x() + nc()/2 - padding_x(),
c*stride_x(), y*stride_y() + nr()/2 - padding_y(),
nr(), nc(),
nc())) nr())))
!*/ !*/
public: public:
...@@ -952,6 +1005,8 @@ namespace dlib ...@@ -952,6 +1005,8 @@ namespace dlib
- #nc() == _nc - #nc() == _nc
- #stride_y() == _stride_y - #stride_y() == _stride_y
- #stride_x() == _stride_x - #stride_x() == _stride_x
- #padding_y() == _padding_y
- #padding_x() == _padding_x
!*/ !*/
long nr( long nr(
...@@ -986,6 +1041,22 @@ namespace dlib ...@@ -986,6 +1041,22 @@ namespace dlib
at a time when it moves over the image. at a time when it moves over the image.
!*/ !*/
long padding_y(
) const;
/*!
ensures
- returns the number of pixels of zero padding added to the top and bottom
sides of the image.
!*/
long padding_x(
) const;
/*!
ensures
- returns the number of pixels of zero padding added to the left and right
sides of the image.
!*/
template <typename SUBNET> void setup (const SUBNET& sub); template <typename SUBNET> void setup (const SUBNET& sub);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output); template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
template <typename SUBNET> void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& params_grad); template <typename SUBNET> void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
......
...@@ -1185,7 +1185,7 @@ namespace ...@@ -1185,7 +1185,7 @@ namespace
} }
{ {
print_spinner(); print_spinner();
con_<3,3,3,2,2> l; con_<3,2,2,2,2> l;
DLIB_TEST_MSG(test_layer(l), test_layer(l)); DLIB_TEST_MSG(test_layer(l), test_layer(l));
} }
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment