Commit 28475b8d authored by Davis King

Made computed_output an optional argument to backward_inplace() so there is
symmetry with the non-inplace version. This also enables additional
optimizations in the resulting network.

parent 122f2fa6
@@ -113,6 +113,15 @@ namespace dlib
         return true;
     }

+    template <typename layer_type, typename SUBNET>
+    constexpr auto backward_requires_forward_output(
+        layer_type& layer,
+        SUBNET& sub
+    ) -> typename alwaysbool<decltype(layer.backward_inplace(rt(),sub.get_gradient_input(),rt()))>::type
+    {
+        return false;
+    }
+
     template <typename layer_type, typename SUBNET>
     constexpr auto has_inplace_backward(
         layer_type& layer,
@@ -140,6 +149,15 @@ namespace dlib
         return true;
     }

+    template <typename layer_type, typename SUBNET>
+    constexpr auto has_inplace_backward(
+        layer_type& layer,
+        SUBNET& sub
+    ) -> typename alwaysbool<decltype(layer.backward_inplace(rt(),sub.get_gradient_input(),rt()))>::type
+    {
+        return true;
+    }
+
     template <typename layer_type, typename SUBNET>
     constexpr auto is_inplace_layer(
         layer_type& layer,
@@ -194,6 +212,18 @@ namespace dlib
         layer.backward_inplace(computed_output,gradient_input,sub.get_gradient_input(),params_grad);
     }

+    template <typename layer_type, typename SUBNET>
+    auto call_layer_backward(
+        layer_type& layer,
+        const tensor& ,
+        const tensor& gradient_input,
+        SUBNET& sub,
+        tensor& params_grad
+    ) -> decltype(layer.backward_inplace(gradient_input,sub.get_gradient_input(),params_grad))
+    {
+        layer.backward_inplace(gradient_input,sub.get_gradient_input(),params_grad);
+    }
+
     template <typename layer_type, typename SUBNET>
     auto call_layer_forward(
...
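The new overloads above work by expression SFINAE: the decltype() in each
trailing return type only compiles when the layer actually declares the
3-argument backward_inplace(), so overload resolution silently selects the
right constexpr answer (or the right call_layer_backward() dispatcher) for
each layer type at compile time. Below is a minimal, self-contained sketch
of the idiom. The stand-in tensor, rt(), subnet, and layer types are
invented here for illustration, and the real dlib code has further
overloads covering the non-inplace backward() as well.

    #include <iostream>

    struct tensor {};   // stand-in for dlib::tensor
    tensor& rt();       // declared only for use inside decltype(); never called

    template <typename T> struct alwaysbool { typedef bool type; };

    struct subnet { tensor& get_gradient_input(); };

    struct old_style_layer {
        // backward_inplace() that needs the forward pass's output:
        void backward_inplace(const tensor& computed_output, const tensor& gradient_input,
                              tensor& data_grad, tensor& params_grad);
    };
    struct new_style_layer {
        // backward_inplace() without computed_output, as enabled by this commit:
        void backward_inplace(const tensor& gradient_input, tensor& data_grad,
                              tensor& params_grad);
    };

    // Selected when the layer has the 4-argument backward_inplace():
    template <typename layer_type, typename SUBNET>
    constexpr auto backward_requires_forward_output(layer_type& layer, SUBNET& sub)
        -> typename alwaysbool<decltype(layer.backward_inplace(rt(),rt(),sub.get_gradient_input(),rt()))>::type
    { return true; }

    // Selected when the layer has the 3-argument backward_inplace():
    template <typename layer_type, typename SUBNET>
    constexpr auto backward_requires_forward_output(layer_type& layer, SUBNET& sub)
        -> typename alwaysbool<decltype(layer.backward_inplace(rt(),sub.get_gradient_input(),rt()))>::type
    { return false; }

    int main()
    {
        old_style_layer a;
        new_style_layer b;
        subnet s;
        std::cout << backward_requires_forward_output(a,s) << '\n';  // prints 1
        std::cout << backward_requires_forward_output(b,s) << '\n';  // prints 0
    }

Because the 3-argument form tells dlib the layer never reads its forward
output, the library is free to reuse that memory, which is the optimization
the commit message refers to.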
@@ -559,7 +559,6 @@ namespace dlib
        }

        void backward_inplace(
-           const tensor& /*computed_output*/,
            const tensor& gradient_input,
            tensor& data_grad,
            tensor& /*params_grad*/
...
@@ -99,7 +99,7 @@ namespace dlib
            to document the interface that a layer object must implement.

            The central work of defining a layer is implementing the forward and backward
-           methods. When you do this you have three options:
+           methods. When you do this you have four options:
                - Implement the forward() and backward() methods according to the
                  specification shown below. Do not implement forward_inplace() and
                  backward_inplace().
@@ -113,6 +113,12 @@ namespace dlib
                  according to the specification shown below. Do not implement
                  forward() and backward(). These in-place methods allow some types of
                  layers to be implemented more efficiently.
+               - Implement the forward_inplace() and backward_inplace() methods
+                 according to the specification shown below, except exclude the
+                 computed_output parameter from backward_inplace(). Doing this will
+                 allow dlib to make some layers execute in-place and therefore run a
+                 little faster and use less memory. Do not implement forward() and
+                 backward().
        !*/

    public:
...@@ -239,7 +245,7 @@ namespace dlib ...@@ -239,7 +245,7 @@ namespace dlib
!*/ !*/
void backward_inplace( void backward_inplace(
const tensor& computed_output, const tensor& computed_output, // this parameter is optional
const tensor& gradient_input, const tensor& gradient_input,
tensor& data_grad, tensor& data_grad,
tensor& params_grad tensor& params_grad
...@@ -503,7 +509,7 @@ namespace dlib ...@@ -503,7 +509,7 @@ namespace dlib
template <typename SUBNET> void setup (const SUBNET& sub); template <typename SUBNET> void setup (const SUBNET& sub);
void forward_inplace(const tensor& input, tensor& output); void forward_inplace(const tensor& input, tensor& output);
void backward_inplace(const tensor& computed_output, const tensor& gradient_input, tensor& data_grad, tensor& params_grad); void backward_inplace(const tensor& gradient_input, tensor& data_grad, tensor& params_grad);
const tensor& get_layer_params() const; const tensor& get_layer_params() const;
tensor& get_layer_params(); tensor& get_layer_params();
/*! /*!
......
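For concreteness, here is what a layer written against the new, fourth
option documented above might look like. This is a hypothetical sketch:
example_doubler_ and its element-wise bodies are invented for illustration,
and it assumes dlib's tensor host()/size() accessors. A real layer would
also provide the serialization functions and other members required by
layers_abstract.h, which are omitted here.

    #include <dlib/dnn.h>
    #include <cstddef>

    // Hypothetical in-place layer computing y = 2*x. Its gradient does not
    // depend on the forward pass's output, so backward_inplace() can omit
    // the computed_output parameter and let dlib run the layer in-place.
    class example_doubler_
    {
    public:
        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/) {}   // nothing to configure

        void forward_inplace(const dlib::tensor& input, dlib::tensor& output)
        {
            // For in-place layers output may alias input; output is assumed
            // to have the same number of elements as input.
            for (std::size_t i = 0; i < input.size(); ++i)
                output.host()[i] = 2*input.host()[i];
        }

        void backward_inplace(
            const dlib::tensor& gradient_input,
            dlib::tensor& data_grad,
            dlib::tensor& /*params_grad*/       // no learnable parameters
        )
        {
            // d(loss)/d(input) = 2 * d(loss)/d(output)
            for (std::size_t i = 0; i < gradient_input.size(); ++i)
                data_grad.host()[i] = 2*gradient_input.host()[i];
        }

        const dlib::tensor& get_layer_params() const { return params; }
        dlib::tensor& get_layer_params() { return params; }

    private:
        dlib::resizable_tensor params;  // empty: this layer has no parameters
    };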