Commit 73a5e943 authored by Davis King's avatar Davis King

Added scale_prev_ layer.

parent c191b376
...@@ -2272,6 +2272,121 @@ namespace dlib ...@@ -2272,6 +2272,121 @@ namespace dlib
using mult_prev9_ = mult_prev_<tag9>; using mult_prev9_ = mult_prev_<tag9>;
using mult_prev10_ = mult_prev_<tag10>; using mult_prev10_ = mult_prev_<tag10>;
// ----------------------------------------------------------------------------------------
template <
template<typename> class tag
>
class scale_prev_
{
public:
const static unsigned long id = tag_id<tag>::id;
scale_prev_()
{
}
template <typename SUBNET>
void setup (const SUBNET& sub)
{
auto&& src = sub.get_output();
reshape_scales = alias_tensor(src.num_samples()*src.k());
reshape_src = alias_tensor(src.num_samples()*src.k(),src.nr()*src.nc());
}
template <typename SUBNET>
void forward(const SUBNET& sub, resizable_tensor& output)
{
auto&& src = sub.get_output();
auto&& scales = layer<tag>(sub).get_output();
DLIB_CASSERT(scales.num_samples() == src.num_samples() &&
scales.k() == src.k() &&
scales.nr() == 1 &&
scales.nc() == 1 );
output.copy_size(src);
tt::scale_channels(false, output, src, scales);
}
template <typename SUBNET>
void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
{
auto&& src = sub.get_output();
auto&& scales = layer<tag>(sub).get_output();
// The gradient just flows backwards to the two layers that forward()
// read from.
tt::scale_channels(true, sub.get_gradient_input(), gradient_input, scales);
auto&& scales_grad = layer<tag>(sub).get_gradient_input();
auto sgrad = reshape_scales(scales_grad);
tt::dot_prods(true, sgrad, reshape_src(src), reshape_src(gradient_input));
}
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
friend void serialize(const scale_prev_& item, std::ostream& out)
{
serialize("scale_prev_", out);
serialize(item.reshape_scales, out);
serialize(item.reshape_src, out);
}
friend void deserialize(scale_prev_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "scale_prev_")
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::scale_prev_.");
deserialize(item.reshape_scales, in);
deserialize(item.reshape_src, in);
}
friend std::ostream& operator<<(std::ostream& out, const scale_prev_& item)
{
out << "scale_prev"<<id;
return out;
}
friend void to_xml(const scale_prev_& item, std::ostream& out)
{
out << "<scale_prev tag='"<<id<<"'/>\n";
}
private:
alias_tensor reshape_scales;
alias_tensor reshape_src;
resizable_tensor params;
};
template <
template<typename> class tag,
typename SUBNET
>
using scale_prev = add_layer<scale_prev_<tag>, SUBNET>;
template <typename SUBNET> using scale_prev1 = scale_prev<tag1, SUBNET>;
template <typename SUBNET> using scale_prev2 = scale_prev<tag2, SUBNET>;
template <typename SUBNET> using scale_prev3 = scale_prev<tag3, SUBNET>;
template <typename SUBNET> using scale_prev4 = scale_prev<tag4, SUBNET>;
template <typename SUBNET> using scale_prev5 = scale_prev<tag5, SUBNET>;
template <typename SUBNET> using scale_prev6 = scale_prev<tag6, SUBNET>;
template <typename SUBNET> using scale_prev7 = scale_prev<tag7, SUBNET>;
template <typename SUBNET> using scale_prev8 = scale_prev<tag8, SUBNET>;
template <typename SUBNET> using scale_prev9 = scale_prev<tag9, SUBNET>;
template <typename SUBNET> using scale_prev10 = scale_prev<tag10, SUBNET>;
using scale_prev1_ = scale_prev_<tag1>;
using scale_prev2_ = scale_prev_<tag2>;
using scale_prev3_ = scale_prev_<tag3>;
using scale_prev4_ = scale_prev_<tag4>;
using scale_prev5_ = scale_prev_<tag5>;
using scale_prev6_ = scale_prev_<tag6>;
using scale_prev7_ = scale_prev_<tag7>;
using scale_prev8_ = scale_prev_<tag8>;
using scale_prev9_ = scale_prev_<tag9>;
using scale_prev10_ = scale_prev_<tag10>;
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
class relu_ class relu_
......
...@@ -2326,6 +2326,76 @@ namespace dlib ...@@ -2326,6 +2326,76 @@ namespace dlib
using mult_prev9_ = mult_prev_<tag9>; using mult_prev9_ = mult_prev_<tag9>;
using mult_prev10_ = mult_prev_<tag10>; using mult_prev10_ = mult_prev_<tag10>;
// ----------------------------------------------------------------------------------------
template <
template<typename> class tag
>
class scale_prev_
{
/*!
WHAT THIS OBJECT REPRESENTS
This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
defined above. This layer scales the output channels of the previous layer
by multiplying it with the output of the tagged layer. To be specific:
- Let INPUT == sub.get_output()
- Let SCALES == layer<tag>(sub).get_output()
- This layer takes INPUT and SCALES as input.
- The output of this layer has the same dimensions as INPUT.
- This layer requires:
- SCALES.num_samples() == INPUT.num_samples()
- SCALES.k() == INPUT.k()
- SCALES.nr() == 1
- SCALES.nc() == 1
- The output tensor is produced by pointwise multiplying SCALES with
INPUT at each spatial location. Therefore, if OUT is the output of
this layer then we would have:
OUT(n,k,r,c) == INPUT(n,k,r,c)*SCALES(n,k)
!*/
public:
scale_prev_(
);
template <typename SUBNET> void setup (const SUBNET& sub);
template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
const tensor& get_layer_params() const;
tensor& get_layer_params();
/*!
These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
!*/
};
template <
template<typename> class tag,
typename SUBNET
>
using scale_prev = add_layer<scale_prev_<tag>, SUBNET>;
// Here we add some convenient aliases for using scale_prev_ with the tag layers.
template <typename SUBNET> using scale_prev1 = scale_prev<tag1, SUBNET>;
template <typename SUBNET> using scale_prev2 = scale_prev<tag2, SUBNET>;
template <typename SUBNET> using scale_prev3 = scale_prev<tag3, SUBNET>;
template <typename SUBNET> using scale_prev4 = scale_prev<tag4, SUBNET>;
template <typename SUBNET> using scale_prev5 = scale_prev<tag5, SUBNET>;
template <typename SUBNET> using scale_prev6 = scale_prev<tag6, SUBNET>;
template <typename SUBNET> using scale_prev7 = scale_prev<tag7, SUBNET>;
template <typename SUBNET> using scale_prev8 = scale_prev<tag8, SUBNET>;
template <typename SUBNET> using scale_prev9 = scale_prev<tag9, SUBNET>;
template <typename SUBNET> using scale_prev10 = scale_prev<tag10, SUBNET>;
using scale_prev1_ = scale_prev_<tag1>;
using scale_prev2_ = scale_prev_<tag2>;
using scale_prev3_ = scale_prev_<tag3>;
using scale_prev4_ = scale_prev_<tag4>;
using scale_prev5_ = scale_prev_<tag5>;
using scale_prev6_ = scale_prev_<tag6>;
using scale_prev7_ = scale_prev_<tag7>;
using scale_prev8_ = scale_prev_<tag8>;
using scale_prev9_ = scale_prev_<tag9>;
using scale_prev10_ = scale_prev_<tag10>;
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template< template<
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment