Commit 0e30960b authored by Davis King's avatar Davis King

Fixed bug in scale_ layer that would trigger when the num_samples() dimension

of a tensor changed.  I.e. if you ran mini-batches with different sample sizes
you would get an assert triggering.  This has been fixed.
parent 28853943
...@@ -2287,11 +2287,8 @@ namespace dlib ...@@ -2287,11 +2287,8 @@ namespace dlib
} }
template <typename SUBNET> template <typename SUBNET>
void setup (const SUBNET& sub) void setup (const SUBNET& /*sub*/)
{ {
auto&& src = sub.get_output();
reshape_scales = alias_tensor(src.num_samples()*src.k());
reshape_src = alias_tensor(src.num_samples()*src.k(),src.nr()*src.nc());
} }
template <typename SUBNET> template <typename SUBNET>
...@@ -2320,6 +2317,12 @@ namespace dlib ...@@ -2320,6 +2317,12 @@ namespace dlib
// read from. // read from.
tt::scale_channels(true, layer<tag>(sub).get_gradient_input(), gradient_input, scales); tt::scale_channels(true, layer<tag>(sub).get_gradient_input(), gradient_input, scales);
if (reshape_src.num_samples() != src.num_samples())
{
reshape_scales = alias_tensor(src.num_samples()*src.k());
reshape_src = alias_tensor(src.num_samples()*src.k(),src.nr()*src.nc());
}
auto&& scales_grad = sub.get_gradient_input(); auto&& scales_grad = sub.get_gradient_input();
auto sgrad = reshape_scales(scales_grad); auto sgrad = reshape_scales(scales_grad);
tt::dot_prods(true, sgrad, reshape_src(src), reshape_src(gradient_input)); tt::dot_prods(true, sgrad, reshape_src(src), reshape_src(gradient_input));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment