Commit f6695521 authored by Davis King's avatar Davis King

Added max_pool_everything and avg_pool_everything.

parent 6d2495a2
...@@ -247,12 +247,14 @@ namespace dlib ...@@ -247,12 +247,14 @@ namespace dlib
> >
class max_pool_ class max_pool_
{ {
static_assert(_nr > 0, "The number of rows in a filter must be > 0"); static_assert(_nr >= 0, "The number of rows in a filter must be >= 0");
static_assert(_nc > 0, "The number of columns in a filter must be > 0"); static_assert(_nc >= 0, "The number of columns in a filter must be >= 0");
static_assert(_stride_y > 0, "The filter stride must be > 0"); static_assert(_stride_y > 0, "The filter stride must be > 0");
static_assert(_stride_x > 0, "The filter stride must be > 0"); static_assert(_stride_x > 0, "The filter stride must be > 0");
static_assert(0 <= _padding_y && _padding_y < _nr, "The padding must be smaller than the filter size."); static_assert(0 <= _padding_y && (_nr==0 && _padding_y == 0 || _nr!=0 && _padding_y < _nr),
static_assert(0 <= _padding_x && _padding_x < _nc, "The padding must be smaller than the filter size."); "The padding must be smaller than the filter size, unless the filters size is 0.");
static_assert(0 <= _padding_x && (_nc==0 && _padding_x == 0 || _nc!=0 && _padding_x < _nc),
"The padding must be smaller than the filter size, unless the filters size is 0.");
public: public:
...@@ -277,7 +279,6 @@ namespace dlib ...@@ -277,7 +279,6 @@ namespace dlib
{ {
// this->mp is non-copyable so we have to write our own copy to avoid trying to // this->mp is non-copyable so we have to write our own copy to avoid trying to
// copy it and getting an error. // copy it and getting an error.
mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x, padding_y_, padding_x_);
} }
max_pool_& operator= ( max_pool_& operator= (
...@@ -292,25 +293,31 @@ namespace dlib ...@@ -292,25 +293,31 @@ namespace dlib
// this->mp is non-copyable so we have to write our own copy to avoid trying to // this->mp is non-copyable so we have to write our own copy to avoid trying to
// copy it and getting an error. // copy it and getting an error.
mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x, padding_y_, padding_x_);
return *this; return *this;
} }
template <typename SUBNET> template <typename SUBNET>
void setup (const SUBNET& /*sub*/) void setup (const SUBNET& /*sub*/)
{ {
mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x, padding_y_, padding_x_);
} }
template <typename SUBNET> template <typename SUBNET>
void forward(const SUBNET& sub, resizable_tensor& output) void forward(const SUBNET& sub, resizable_tensor& output)
{ {
mp.setup_avg_pooling(_nr!=0?_nr:sub.get_output().nr(),
_nc!=0?_nc:sub.get_output().nc(),
_stride_y, _stride_x, padding_y_, padding_x_);
mp(output, sub.get_output()); mp(output, sub.get_output());
} }
template <typename SUBNET> template <typename SUBNET>
void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
{ {
mp.setup_avg_pooling(_nr!=0?_nr:sub.get_output().nr(),
_nc!=0?_nc:sub.get_output().nc(),
_stride_y, _stride_x, padding_y_, padding_x_);
mp.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input()); mp.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input());
} }
...@@ -365,9 +372,6 @@ namespace dlib ...@@ -365,9 +372,6 @@ namespace dlib
if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::max_pool_"); if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::max_pool_");
if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::max_pool_"); if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::max_pool_");
if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::max_pool_"); if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::max_pool_");
item.mp.setup_max_pooling(_nr, _nc, _stride_y, _stride_x, item.padding_y_, item.padding_x_);
} }
friend std::ostream& operator<<(std::ostream& out, const max_pool_& item) friend std::ostream& operator<<(std::ostream& out, const max_pool_& item)
...@@ -403,6 +407,11 @@ namespace dlib ...@@ -403,6 +407,11 @@ namespace dlib
> >
using max_pool = add_layer<max_pool_<nr,nc,stride_y,stride_x>, SUBNET>; using max_pool = add_layer<max_pool_<nr,nc,stride_y,stride_x>, SUBNET>;
template <
typename SUBNET
>
using max_pool_everything = add_layer<max_pool_<0,0,1,1>, SUBNET>;
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -416,12 +425,14 @@ namespace dlib ...@@ -416,12 +425,14 @@ namespace dlib
class avg_pool_ class avg_pool_
{ {
public: public:
static_assert(_nr > 0, "The number of rows in a filter must be > 0"); static_assert(_nr >= 0, "The number of rows in a filter must be >= 0");
static_assert(_nc > 0, "The number of columns in a filter must be > 0"); static_assert(_nc >= 0, "The number of columns in a filter must be >= 0");
static_assert(_stride_y > 0, "The filter stride must be > 0"); static_assert(_stride_y > 0, "The filter stride must be > 0");
static_assert(_stride_x > 0, "The filter stride must be > 0"); static_assert(_stride_x > 0, "The filter stride must be > 0");
static_assert(0 <= _padding_y && _padding_y < _nr, "The padding must be smaller than the filter size."); static_assert(0 <= _padding_y && (_nr==0 && _padding_y == 0 || _nr!=0 && _padding_y < _nr),
static_assert(0 <= _padding_x && _padding_x < _nc, "The padding must be smaller than the filter size."); "The padding must be smaller than the filter size, unless the filters size is 0.");
static_assert(0 <= _padding_x && (_nc==0 && _padding_x == 0 || _nc!=0 && _padding_x < _nc),
"The padding must be smaller than the filter size, unless the filters size is 0.");
avg_pool_( avg_pool_(
) : ) :
...@@ -444,7 +455,6 @@ namespace dlib ...@@ -444,7 +455,6 @@ namespace dlib
{ {
// this->ap is non-copyable so we have to write our own copy to avoid trying to // this->ap is non-copyable so we have to write our own copy to avoid trying to
// copy it and getting an error. // copy it and getting an error.
ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x, padding_y_, padding_x_);
} }
avg_pool_& operator= ( avg_pool_& operator= (
...@@ -459,25 +469,31 @@ namespace dlib ...@@ -459,25 +469,31 @@ namespace dlib
// this->ap is non-copyable so we have to write our own copy to avoid trying to // this->ap is non-copyable so we have to write our own copy to avoid trying to
// copy it and getting an error. // copy it and getting an error.
ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x, padding_y_, padding_x_);
return *this; return *this;
} }
template <typename SUBNET> template <typename SUBNET>
void setup (const SUBNET& /*sub*/) void setup (const SUBNET& /*sub*/)
{ {
ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x, padding_y_, padding_x_);
} }
template <typename SUBNET> template <typename SUBNET>
void forward(const SUBNET& sub, resizable_tensor& output) void forward(const SUBNET& sub, resizable_tensor& output)
{ {
ap.setup_avg_pooling(_nr!=0?_nr:sub.get_output().nr(),
_nc!=0?_nc:sub.get_output().nc(),
_stride_y, _stride_x, padding_y_, padding_x_);
ap(output, sub.get_output()); ap(output, sub.get_output());
} }
template <typename SUBNET> template <typename SUBNET>
void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
{ {
ap.setup_avg_pooling(_nr!=0?_nr:sub.get_output().nr(),
_nc!=0?_nc:sub.get_output().nc(),
_stride_y, _stride_x, padding_y_, padding_x_);
ap.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input()); ap.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input());
} }
...@@ -533,8 +549,6 @@ namespace dlib ...@@ -533,8 +549,6 @@ namespace dlib
if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::avg_pool_"); if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::avg_pool_");
if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::avg_pool_"); if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::avg_pool_");
if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::avg_pool_"); if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::avg_pool_");
item.ap.setup_avg_pooling(_nr, _nc, _stride_y, _stride_x, item.padding_y_, item.padding_x_);
} }
friend std::ostream& operator<<(std::ostream& out, const avg_pool_& item) friend std::ostream& operator<<(std::ostream& out, const avg_pool_& item)
...@@ -567,6 +581,11 @@ namespace dlib ...@@ -567,6 +581,11 @@ namespace dlib
> >
using avg_pool = add_layer<avg_pool_<nr,nc,stride_y,stride_x>, SUBNET>; using avg_pool = add_layer<avg_pool_<nr,nc,stride_y,stride_x>, SUBNET>;
template <
typename SUBNET
>
using avg_pool_everything = add_layer<avg_pool_<0,0,1,1>, SUBNET>;
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
enum layer_mode enum layer_mode
......
...@@ -844,10 +844,20 @@ namespace dlib ...@@ -844,10 +844,20 @@ namespace dlib
{ {
/*! /*!
REQUIREMENTS ON TEMPLATE ARGUMENTS REQUIREMENTS ON TEMPLATE ARGUMENTS
All of them must be > 0. - _nr >= 0
Also, we require that: - _nc >= 0
- 0 <= _padding_y && _padding_y < _nr - _stride_y > 0
- 0 <= _padding_x && _padding_x < _nc - _stride_x > 0
- _padding_y >= 0
- _padding_x >= 0
- if (_nr != 0) then
- _padding_y < _nr
- else
- _padding_y == 0
- if (_nc != 0) then
- _padding_x < _nc
- else
- _padding_x == 0
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
...@@ -856,18 +866,21 @@ namespace dlib ...@@ -856,18 +866,21 @@ namespace dlib
images in an input tensor and outputting, for each channel, the maximum images in an input tensor and outputting, for each channel, the maximum
element within the window. element within the window.
To be precise, if we call the input tensor IN and the output tensor OUT, If _nr == 0 then it means the filter size covers all the rows in the input
then OUT is defined as follows: tensor, similarly for the _nc parameter. To be precise, if we call the
input tensor IN and the output tensor OUT, then OUT is defined as follows:
- let FILT_NR == (nr()==0) ? IN.nr() : nr()
- let FILT_NC == (nc()==0) ? IN.nc() : nc()
- OUT.num_samples() == IN.num_samples() - OUT.num_samples() == IN.num_samples()
- OUT.k() == IN.k() - OUT.k() == IN.k()
- OUT.nr() == 1+(IN.nr() + 2*padding_y() - nr())/stride_y() - OUT.nr() == 1+(IN.nr() + 2*padding_y() - FILT_NR)/stride_y()
- OUT.nc() == 1+(IN.nc() + 2*padding_x() - nc())/stride_x() - OUT.nc() == 1+(IN.nc() + 2*padding_x() - FILT_NC)/stride_x()
- for all valid s, k, r, and c: - for all valid s, k, r, and c:
- image_plane(OUT,s,k)(r,c) == max(subm_clipped(image_plane(IN,s,k), - image_plane(OUT,s,k)(r,c) == max(subm_clipped(image_plane(IN,s,k),
centered_rect(x*stride_x() + nc()/2 - padding_x(), centered_rect(x*stride_x() + FILT_NC/2 - padding_x(),
y*stride_y() + nr()/2 - padding_y(), y*stride_y() + FILT_NR/2 - padding_y(),
nc(), FILT_NC,
nr()))) FILT_NR)))
!*/ !*/
public: public:
...@@ -888,14 +901,16 @@ namespace dlib ...@@ -888,14 +901,16 @@ namespace dlib
) const; ) const;
/*! /*!
ensures ensures
- returns the number of rows in the max pooling window. - returns the number of rows in the pooling window or 0 if the window size
is "the entire input tensor".
!*/ !*/
long nc( long nc(
) const; ) const;
/*! /*!
ensures ensures
- returns the number of columns in the max pooling window. - returns the number of columns in the pooling window or 0 if the window size
is "the entire input tensor".
!*/ !*/
long stride_y( long stride_y(
...@@ -953,6 +968,11 @@ namespace dlib ...@@ -953,6 +968,11 @@ namespace dlib
> >
using max_pool = add_layer<max_pool_<nr,nc,stride_y,stride_x>, SUBNET>; using max_pool = add_layer<max_pool_<nr,nc,stride_y,stride_x>, SUBNET>;
template <
typename SUBNET
>
using max_pool_everything = add_layer<max_pool_<0,0,1,1>, SUBNET>;
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -967,10 +987,20 @@ namespace dlib ...@@ -967,10 +987,20 @@ namespace dlib
{ {
/*! /*!
REQUIREMENTS ON TEMPLATE ARGUMENTS REQUIREMENTS ON TEMPLATE ARGUMENTS
All of them must be > 0. - _nr >= 0
Also, we require that: - _nc >= 0
- 0 <= _padding_y && _padding_y < _nr - _stride_y > 0
- 0 <= _padding_x && _padding_x < _nc - _stride_x > 0
- _padding_y >= 0
- _padding_x >= 0
- if (_nr != 0) then
- _padding_y < _nr
- else
- _padding_y == 0
- if (_nc != 0) then
- _padding_x < _nc
- else
- _padding_x == 0
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
...@@ -979,20 +1009,21 @@ namespace dlib ...@@ -979,20 +1009,21 @@ namespace dlib
over the images in an input tensor and outputting, for each channel, the over the images in an input tensor and outputting, for each channel, the
average element within the window. average element within the window.
To be precise, if we call the input tensor IN and the output tensor OUT, If _nr == 0 then it means the filter size covers all the rows in the input
then OUT is defined as follows: tensor, similarly for the _nc parameter. To be precise, if we call the
input tensor IN and the output tensor OUT, then OUT is defined as follows:
- let FILT_NR == (nr()==0) ? IN.nr() : nr()
- let FILT_NC == (nc()==0) ? IN.nc() : nc()
- OUT.num_samples() == IN.num_samples() - OUT.num_samples() == IN.num_samples()
- OUT.k() == IN.k() - OUT.k() == IN.k()
- OUT.nr() == 1+(IN.nr()-nr()%2)/stride_y() - OUT.nr() == 1+(IN.nr() + 2*padding_y() - FILT_NR)/stride_y()
- OUT.nc() == 1+(IN.nc()-nc()%2)/stride_x() - OUT.nc() == 1+(IN.nc() + 2*padding_x() - FILT_NC)/stride_x()
- OUT.nr() == 1+(IN.nr() + 2*padding_y() - nr())/stride_y()
- OUT.nc() == 1+(IN.nc() + 2*padding_x() - nc())/stride_x()
- for all valid s, k, r, and c: - for all valid s, k, r, and c:
- image_plane(OUT,s,k)(r,c) == mean(subm_clipped(image_plane(IN,s,k), - image_plane(OUT,s,k)(r,c) == mean(subm_clipped(image_plane(IN,s,k),
centered_rect(x*stride_x() + nc()/2 - padding_x(), centered_rect(x*stride_x() + FILT_NC/2 - padding_x(),
y*stride_y() + nr()/2 - padding_y(), y*stride_y() + FILT_NR/2 - padding_y(),
nc(), FILT_NC,
nr()))) FILT_NR)))
!*/ !*/
public: public:
...@@ -1013,14 +1044,16 @@ namespace dlib ...@@ -1013,14 +1044,16 @@ namespace dlib
) const; ) const;
/*! /*!
ensures ensures
- returns the number of rows in the pooling window. - returns the number of rows in the pooling window or 0 if the window size
is "the entire input tensor".
!*/ !*/
long nc( long nc(
) const; ) const;
/*! /*!
ensures ensures
- returns the number of columns in the pooling window. - returns the number of columns in the pooling window or 0 if the window size
is "the entire input tensor".
!*/ !*/
long stride_y( long stride_y(
...@@ -1079,6 +1112,11 @@ namespace dlib ...@@ -1079,6 +1112,11 @@ namespace dlib
> >
using avg_pool = add_layer<avg_pool_<nr,nc,stride_y,stride_x>, SUBNET>; using avg_pool = add_layer<avg_pool_<nr,nc,stride_y,stride_x>, SUBNET>;
template <
typename SUBNET
>
using avg_pool_everything = add_layer<avg_pool_<0,0,1,1>, SUBNET>;
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
class relu_ class relu_
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment