Commit 9ca26760 authored by Davis King's avatar Davis King

Renamed scale_prev_ to scale_ and swapped the layers it reads from. Now

the scales come from the immediate predecessor and the tensor to be scaled
from the tag.
parent 0f51dfb9
...@@ -2277,12 +2277,12 @@ namespace dlib ...@@ -2277,12 +2277,12 @@ namespace dlib
template < template <
template<typename> class tag template<typename> class tag
> >
class scale_prev_ class scale_
{ {
public: public:
const static unsigned long id = tag_id<tag>::id; const static unsigned long id = tag_id<tag>::id;
scale_prev_() scale_()
{ {
} }
...@@ -2297,8 +2297,8 @@ namespace dlib ...@@ -2297,8 +2297,8 @@ namespace dlib
template <typename SUBNET> template <typename SUBNET>
void forward(const SUBNET& sub, resizable_tensor& output) void forward(const SUBNET& sub, resizable_tensor& output)
{ {
auto&& src = sub.get_output(); auto&& scales = sub.get_output();
auto&& scales = layer<tag>(sub).get_output(); auto&& src = layer<tag>(sub).get_output();
DLIB_CASSERT(scales.num_samples() == src.num_samples() && DLIB_CASSERT(scales.num_samples() == src.num_samples() &&
scales.k() == src.k() && scales.k() == src.k() &&
scales.nr() == 1 && scales.nr() == 1 &&
...@@ -2311,13 +2311,13 @@ namespace dlib ...@@ -2311,13 +2311,13 @@ namespace dlib
template <typename SUBNET> template <typename SUBNET>
void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/) void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
{ {
auto&& src = sub.get_output(); auto&& scales = sub.get_output();
auto&& scales = layer<tag>(sub).get_output(); auto&& src = layer<tag>(sub).get_output();
// The gradient just flows backwards to the two layers that forward() // The gradient just flows backwards to the two layers that forward()
// read from. // read from.
tt::scale_channels(true, sub.get_gradient_input(), gradient_input, scales); tt::scale_channels(true, layer<tag>(sub).get_gradient_input(), gradient_input, scales);
auto&& scales_grad = layer<tag>(sub).get_gradient_input(); auto&& scales_grad = sub.get_gradient_input();
auto sgrad = reshape_scales(scales_grad); auto sgrad = reshape_scales(scales_grad);
tt::dot_prods(true, sgrad, reshape_src(src), reshape_src(gradient_input)); tt::dot_prods(true, sgrad, reshape_src(src), reshape_src(gradient_input));
} }
...@@ -2325,32 +2325,32 @@ namespace dlib ...@@ -2325,32 +2325,32 @@ namespace dlib
const tensor& get_layer_params() const { return params; } const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; } tensor& get_layer_params() { return params; }
friend void serialize(const scale_prev_& item, std::ostream& out) friend void serialize(const scale_& item, std::ostream& out)
{ {
serialize("scale_prev_", out); serialize("scale_", out);
serialize(item.reshape_scales, out); serialize(item.reshape_scales, out);
serialize(item.reshape_src, out); serialize(item.reshape_src, out);
} }
friend void deserialize(scale_prev_& item, std::istream& in) friend void deserialize(scale_& item, std::istream& in)
{ {
std::string version; std::string version;
deserialize(version, in); deserialize(version, in);
if (version != "scale_prev_") if (version != "scale_")
throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::scale_prev_."); throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::scale_.");
deserialize(item.reshape_scales, in); deserialize(item.reshape_scales, in);
deserialize(item.reshape_src, in); deserialize(item.reshape_src, in);
} }
friend std::ostream& operator<<(std::ostream& out, const scale_prev_& item) friend std::ostream& operator<<(std::ostream& out, const scale_& item)
{ {
out << "scale_prev"<<id; out << "scale"<<id;
return out; return out;
} }
friend void to_xml(const scale_prev_& item, std::ostream& out) friend void to_xml(const scale_& item, std::ostream& out)
{ {
out << "<scale_prev tag='"<<id<<"'/>\n"; out << "<scale tag='"<<id<<"'/>\n";
} }
private: private:
...@@ -2363,29 +2363,29 @@ namespace dlib ...@@ -2363,29 +2363,29 @@ namespace dlib
template<typename> class tag, template<typename> class tag,
typename SUBNET typename SUBNET
> >
using scale_prev = add_layer<scale_prev_<tag>, SUBNET>; using scale = add_layer<scale_<tag>, SUBNET>;
template <typename SUBNET> using scale_prev1 = scale_prev<tag1, SUBNET>; template <typename SUBNET> using scale1 = scale<tag1, SUBNET>;
template <typename SUBNET> using scale_prev2 = scale_prev<tag2, SUBNET>; template <typename SUBNET> using scale2 = scale<tag2, SUBNET>;
template <typename SUBNET> using scale_prev3 = scale_prev<tag3, SUBNET>; template <typename SUBNET> using scale3 = scale<tag3, SUBNET>;
template <typename SUBNET> using scale_prev4 = scale_prev<tag4, SUBNET>; template <typename SUBNET> using scale4 = scale<tag4, SUBNET>;
template <typename SUBNET> using scale_prev5 = scale_prev<tag5, SUBNET>; template <typename SUBNET> using scale5 = scale<tag5, SUBNET>;
template <typename SUBNET> using scale_prev6 = scale_prev<tag6, SUBNET>; template <typename SUBNET> using scale6 = scale<tag6, SUBNET>;
template <typename SUBNET> using scale_prev7 = scale_prev<tag7, SUBNET>; template <typename SUBNET> using scale7 = scale<tag7, SUBNET>;
template <typename SUBNET> using scale_prev8 = scale_prev<tag8, SUBNET>; template <typename SUBNET> using scale8 = scale<tag8, SUBNET>;
template <typename SUBNET> using scale_prev9 = scale_prev<tag9, SUBNET>; template <typename SUBNET> using scale9 = scale<tag9, SUBNET>;
template <typename SUBNET> using scale_prev10 = scale_prev<tag10, SUBNET>; template <typename SUBNET> using scale10 = scale<tag10, SUBNET>;
using scale_prev1_ = scale_prev_<tag1>; using scale1_ = scale_<tag1>;
using scale_prev2_ = scale_prev_<tag2>; using scale2_ = scale_<tag2>;
using scale_prev3_ = scale_prev_<tag3>; using scale3_ = scale_<tag3>;
using scale_prev4_ = scale_prev_<tag4>; using scale4_ = scale_<tag4>;
using scale_prev5_ = scale_prev_<tag5>; using scale5_ = scale_<tag5>;
using scale_prev6_ = scale_prev_<tag6>; using scale6_ = scale_<tag6>;
using scale_prev7_ = scale_prev_<tag7>; using scale7_ = scale_<tag7>;
using scale_prev8_ = scale_prev_<tag8>; using scale8_ = scale_<tag8>;
using scale_prev9_ = scale_prev_<tag9>; using scale9_ = scale_<tag9>;
using scale_prev10_ = scale_prev_<tag10>; using scale10_ = scale_<tag10>;
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -2331,15 +2331,15 @@ namespace dlib ...@@ -2331,15 +2331,15 @@ namespace dlib
template < template <
template<typename> class tag template<typename> class tag
> >
class scale_prev_ class scale_
{ {
/*! /*!
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
defined above. This layer scales the output channels of the previous layer defined above. This layer scales the output channels of the tagged layer
by multiplying it with the output of the tagged layer. To be specific: by multiplying it with the output of the previous layer. To be specific:
- Let INPUT == sub.get_output() - Let INPUT == layer<tag>(sub).get_output()
- Let SCALES == layer<tag>(sub).get_output() - Let SCALES == sub.get_output()
- This layer takes INPUT and SCALES as input. - This layer takes INPUT and SCALES as input.
- The output of this layer has the same dimensions as INPUT. - The output of this layer has the same dimensions as INPUT.
- This layer requires: - This layer requires:
...@@ -2354,7 +2354,7 @@ namespace dlib ...@@ -2354,7 +2354,7 @@ namespace dlib
!*/ !*/
public: public:
scale_prev_( scale_(
); );
template <typename SUBNET> void setup (const SUBNET& sub); template <typename SUBNET> void setup (const SUBNET& sub);
...@@ -2372,29 +2372,29 @@ namespace dlib ...@@ -2372,29 +2372,29 @@ namespace dlib
template<typename> class tag, template<typename> class tag,
typename SUBNET typename SUBNET
> >
using scale_prev = add_layer<scale_prev_<tag>, SUBNET>; using scale = add_layer<scale_<tag>, SUBNET>;
// Here we add some convenient aliases for using scale_prev_ with the tag layers. // Here we add some convenient aliases for using scale_ with the tag layers.
template <typename SUBNET> using scale_prev1 = scale_prev<tag1, SUBNET>; template <typename SUBNET> using scale1 = scale<tag1, SUBNET>;
template <typename SUBNET> using scale_prev2 = scale_prev<tag2, SUBNET>; template <typename SUBNET> using scale2 = scale<tag2, SUBNET>;
template <typename SUBNET> using scale_prev3 = scale_prev<tag3, SUBNET>; template <typename SUBNET> using scale3 = scale<tag3, SUBNET>;
template <typename SUBNET> using scale_prev4 = scale_prev<tag4, SUBNET>; template <typename SUBNET> using scale4 = scale<tag4, SUBNET>;
template <typename SUBNET> using scale_prev5 = scale_prev<tag5, SUBNET>; template <typename SUBNET> using scale5 = scale<tag5, SUBNET>;
template <typename SUBNET> using scale_prev6 = scale_prev<tag6, SUBNET>; template <typename SUBNET> using scale6 = scale<tag6, SUBNET>;
template <typename SUBNET> using scale_prev7 = scale_prev<tag7, SUBNET>; template <typename SUBNET> using scale7 = scale<tag7, SUBNET>;
template <typename SUBNET> using scale_prev8 = scale_prev<tag8, SUBNET>; template <typename SUBNET> using scale8 = scale<tag8, SUBNET>;
template <typename SUBNET> using scale_prev9 = scale_prev<tag9, SUBNET>; template <typename SUBNET> using scale9 = scale<tag9, SUBNET>;
template <typename SUBNET> using scale_prev10 = scale_prev<tag10, SUBNET>; template <typename SUBNET> using scale10 = scale<tag10, SUBNET>;
using scale_prev1_ = scale_prev_<tag1>; using scale1_ = scale_<tag1>;
using scale_prev2_ = scale_prev_<tag2>; using scale2_ = scale_<tag2>;
using scale_prev3_ = scale_prev_<tag3>; using scale3_ = scale_<tag3>;
using scale_prev4_ = scale_prev_<tag4>; using scale4_ = scale_<tag4>;
using scale_prev5_ = scale_prev_<tag5>; using scale5_ = scale_<tag5>;
using scale_prev6_ = scale_prev_<tag6>; using scale6_ = scale_<tag6>;
using scale_prev7_ = scale_prev_<tag7>; using scale7_ = scale_<tag7>;
using scale_prev8_ = scale_prev_<tag8>; using scale8_ = scale_<tag8>;
using scale_prev9_ = scale_prev_<tag9>; using scale9_ = scale_<tag9>;
using scale_prev10_ = scale_prev_<tag10>; using scale10_ = scale_<tag10>;
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment