Commit c48a6af8 authored by Davis King

Added a way to get the final gradient with respect to the inputs.  Also added a method to more efficiently give the input gradient in some instances.
parent 3597df5e
@@ -731,13 +731,22 @@ namespace dlib
             return private_get_gradient_input();
         }
 
+        const tensor& get_final_data_gradient(
+        ) const { return subnetwork.get_final_data_gradient(); }
+
         template <typename solver_type>
         void update(const tensor& x, sstack<solver_type,num_layers>& solvers)
+        {
+            update(x,private_get_gradient_input(),solvers);
+        }
+
+        template <typename solver_type>
+        void update(const tensor& x, const tensor& gradient_input, sstack<solver_type,num_layers>& solvers)
         {
             dimpl::subnet_wrapper<subnet_type> wsub(*subnetwork);
             params_grad.copy_size(details.get_layer_params());
             impl::call_layer_backward(details, private_get_output(),
-                private_get_gradient_input(), wsub, static_cast<tensor&>(params_grad));
+                gradient_input, wsub, static_cast<tensor&>(params_grad));
             // Don't try to adjust the parameters if this layer doesn't have any.
             if (params_grad.size() != 0)
                 solvers.top()(details, static_cast<const tensor&>(params_grad));
@@ -1015,13 +1024,22 @@ namespace dlib
             return private_get_gradient_input();
         }
 
+        const tensor& get_final_data_gradient(
+        ) const { return grad_final_ignored; }
+
         template <typename solver_type>
         void update(const tensor& x, sstack<solver_type,num_layers>& solvers)
+        {
+            update(x,private_get_gradient_input(),solvers);
+        }
+
+        template <typename solver_type>
+        void update(const tensor& x, const tensor& gradient_input, sstack<solver_type,num_layers>& solvers)
         {
             subnet_wrapper wsub(x, grad_final_ignored);
             params_grad.copy_size(details.get_layer_params());
             impl::call_layer_backward(details, private_get_output(),
-                private_get_gradient_input(), wsub, static_cast<tensor&>(params_grad));
+                gradient_input, wsub, static_cast<tensor&>(params_grad));
             // Don't try to adjust the parameters if this layer doesn't have any.
             if (params_grad.size() != 0)
                 solvers.top()(details, static_cast<const tensor&>(params_grad));
@@ -1210,12 +1228,21 @@ namespace dlib
             return subnetwork.get_gradient_input();
         }
 
+        const tensor& get_final_data_gradient(
+        ) const { return subnetwork.get_final_data_gradient(); }
+
         template <typename solver_type>
         void update(const tensor& x, sstack<solver_type,num_layers>& solvers)
         {
             subnetwork.update(x,solvers.pop());
         }
 
+        template <typename solver_type>
+        void update(const tensor& x, const tensor& gradient_input, sstack<solver_type,num_layers>& solvers)
+        {
+            subnetwork.update(x,gradient_input,solvers.pop());
+        }
+
         const subnet_type& subnet() const { return subnetwork; }
         subnet_type& subnet() { return subnetwork; }
@@ -1346,6 +1373,9 @@ namespace dlib
             return cached_output;
         }
 
+        const tensor& get_final_data_gradient(
+        ) const { return grad_final_ignored; }
+
         tensor& get_gradient_input()
         {
             if (!have_same_dimensions(cached_output, grad_final_ignored))
@@ -1362,6 +1392,12 @@ namespace dlib
             // nothing to update
         }
 
+        template <typename solver_type>
+        void update(const tensor& /*x*/, const tensor& gradient_input, sstack<solver_type,num_layers>& /*solvers*/)
+        {
+            // nothing to update
+        }
+
         const subnet_type& subnet() const { return input_layer; }
         subnet_type& subnet() { return input_layer; }
@@ -1881,12 +1917,24 @@ namespace dlib
             return layer<TAG_TYPE>(subnetwork).get_gradient_input();
         }
 
+        const tensor& get_final_data_gradient(
+        ) const
+        {
+            return subnetwork.get_final_data_gradient();
+        }
+
         template <typename solver_type>
         void update(const tensor& x, sstack<solver_type,num_layers>& solvers)
         {
             subnetwork.update(x,solvers.pop());
         }
 
+        template <typename solver_type>
+        void update(const tensor& x, const tensor& gradient_input, sstack<solver_type,num_layers>& solvers)
+        {
+            subnetwork.update(x,gradient_input,solvers.pop());
+        }
+
         const subnet_type& subnet() const
         {
             return subnetwork;
...
@@ -393,6 +393,16 @@ namespace dlib
                   update() method.
         !*/
 
+        const tensor& get_final_data_gradient(
+        ) const;
+        /*!
+            ensures
+                - if update() has been called to back-propagate a gradient through this
+                  network then you can call get_final_data_gradient() to obtain the last
+                  gradient computed.  That is, this function returns the gradient of the
+                  network with respect to its inputs.
+        !*/
+
         template <typename solver_type>
         void update(
             const tensor& x,
@@ -412,6 +422,38 @@ namespace dlib
                 - Back propagates the error gradient, get_gradient_input(), through this
                   network and uses the provided solvers to update the network parameters.
                 - All elements of #get_gradient_input() are set to 0.
+                - have_same_dimensions(#get_final_data_gradient(), x) == true
+                - #get_final_data_gradient() contains the gradient of the network with
+                  respect to x.
+        !*/
+
+        template <typename solver_type>
+        void update(
+            const tensor& x,
+            const tensor& gradient_input,
+            sstack<solver_type,num_layers>& solvers
+        );
+        /*!
+            requires
+                - forward(x) was called to forward propagate x through the network.
+                  Moreover, this was the most recent call to forward() and x has not been
+                  subsequently modified in any way.
+                - have_same_dimensions(gradient_input, get_output()) == true
+                - This instance of solvers has only ever been used with this network.  That
+                  is, if you want to call update() on some other neural network object then
+                  you must not reuse the same solvers object.
+            ensures
+                - This function is identical to the version of update() defined immediately
+                  above except that it back-propagates gradient_input through the network
+                  instead of get_gradient_input().  Therefore, this version of update is
+                  equivalent to performing:
+                    get_gradient_input() = gradient_input;
+                    update(x,solvers);
+                  Except that calling update(x,gradient_input,solvers) avoids the copy
+                  and is therefore slightly more efficient.
+                - All elements of #get_gradient_input() are set to 0.
+                - #get_final_data_gradient() contains the gradient of the network with
+                  respect to x.
         !*/
 
         void clean(
...
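
For context, here is a minimal usage sketch of the two additions.  It is not part of the commit: the helper name custom_update_step, the all-ones placeholder gradient, and the assumption that <dlib/dnn.h> pulls in the tensor and network headers are illustrative only, and net, x, and solvers are assumed to satisfy the requires clauses quoted above.

    #include <dlib/dnn.h>  // assumed entry point for the dnn core and tensor headers

    // Hypothetical helper, not from the commit: runs one training step using the
    // new update() overload and then reads back the gradient w.r.t. the input.
    template <typename net_type, typename solver_type>
    void custom_update_step (
        net_type& net,
        const dlib::tensor& x,
        dlib::sstack<solver_type, net_type::num_layers>& solvers
    )
    {
        // forward(x) must be the most recent forward pass, per update()'s requires clause.
        net.forward(x);

        // Some gradient with respect to the network's output.  A real loss would
        // compute this from labels; an all-ones tensor is just a placeholder.
        dlib::resizable_tensor gradient_input;
        gradient_input.copy_size(net.get_output());
        gradient_input = 1;

        // New overload: back-propagates gradient_input directly, equivalent to
        // assigning it into get_gradient_input() and calling update(x,solvers),
        // but without the extra copy.
        net.update(x, gradient_input, solvers);

        // New accessor: the gradient of the network with respect to its input x.
        // After update(), have_same_dimensions(dx, x) == true per the spec above.
        const dlib::tensor& dx = net.get_final_data_gradient();
        (void)dx;  // e.g. pass this upstream or inspect it
    }

The point of the new overload is that a caller which already owns a gradient over get_output(), such as a loss layer, can hand it to update() directly rather than first copying it into get_gradient_input().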