Commit df7d7f03 authored by Davis King's avatar Davis King

Added max_pool_ layer.

parent 7ae43ae2
......@@ -176,6 +176,126 @@ namespace dlib
template <typename SUBNET>
using con = add_layer<con_, SUBNET>;
// ----------------------------------------------------------------------------------------
class max_pool_
{
public:
max_pool_ (
) :
_nr(3),
_nc(3),
_stride_y(1),
_stride_x(1)
{}
max_pool_(
long nr_,
long nc_,
int stride_y_ = 1,
int stride_x_ = 1
) :
_nr(nr_),
_nc(nc_),
_stride_y(stride_y_),
_stride_x(stride_x_)
{}
long nr() const { return _nr; }
long nc() const { return _nc; }
long stride_y() const { return _stride_y; }
long stride_x() const { return _stride_x; }
max_pool_ (
const max_pool_& item
) :
_nr(item._nr),
_nc(item._nc),
_stride_y(item._stride_y),
_stride_x(item._stride_x)
{
// this->mp is non-copyable so we have to write our own copy to avoid trying to
// copy it and getting an error.
mp.setup(_nr, _nc, _stride_y, _stride_x);
}
max_pool_& operator= (
const max_pool_& item
)
{
if (this == &item)
return *this;
// this->mp is non-copyable so we have to write our own copy to avoid trying to
// copy it and getting an error.
_nr = item._nr;
_nc = item._nc;
_stride_y = item._stride_y;
_stride_x = item._stride_x;
mp.setup(_nr, _nc, _stride_y, _stride_x);
return *this;
}
template <typename SUBNET>
void setup (const SUBNET& /*sub*/)
{
mp.setup(_nr, _nc, _stride_y, _stride_x);
}
template <typename SUBNET>
void forward(const SUBNET& sub, resizable_tensor& output)
{
mp(output, sub.get_output());
}
template <typename SUBNET>
void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
{
mp.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input());
}
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
friend void serialize(const max_pool_& item, std::ostream& out)
{
serialize("max_pool_", out);
serialize(item._nr, out);
serialize(item._nc, out);
serialize(item._stride_y, out);
serialize(item._stride_y, out);
}
friend void deserialize(max_pool_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "max_pool_")
throw serialization_error("Unexpected version found while deserializing dlib::max_pool_.");
deserialize(item._nr, in);
deserialize(item._nc, in);
deserialize(item._stride_y, in);
deserialize(item._stride_y, in);
item.mp.setup(item._nr, item._nc, item._stride_y, item._stride_x);
}
private:
long _nr;
long _nc;
int _stride_y;
int _stride_x;
tt::max_pool mp;
resizable_tensor params;
};
template <typename SUBNET>
using max_pool = add_layer<max_pool_, SUBNET>;
// ----------------------------------------------------------------------------------------
class bn_
......
......@@ -522,7 +522,7 @@ namespace dlib
// ----------------------------------------------------------------------------------------
// TODO, add spec for bn_ and affine_ layers.
// TODO, add spec for max_pool_, bn_, and affine_ layers.
// ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment