Commit 3597df5e authored by Davis King

Made add_layer hold subnetworks through a pointer so that most of a
network is allocated on the heap rather than resulting in really large
stack usage for large networks.
parent 72b250bb
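To see why this matters, here is a minimal, self-contained sketch (not dlib code; big_payload and the two holder types below are hypothetical). A member held by value makes the enclosing object at least as large as that member, so a deeply nested network declared as a local variable can easily exhaust the stack. Holding the member through a std::unique_ptr keeps the enclosing object pointer-sized and moves the bulk of the storage onto the heap.

    #include <array>
    #include <iostream>
    #include <memory>

    // Hypothetical stand-in for one layer's worth of state.
    struct big_payload
    {
        std::array<float, 1 << 20> weights{};  // roughly 4 MB
    };

    // Member held by value: sizeof(by_value) is itself ~4 MB, so a local
    // variable of this type consumes that much stack space.
    struct by_value
    {
        big_payload sub;
    };

    // Member held through a pointer: the object is pointer-sized and the
    // payload is allocated on the heap when the object is constructed.
    struct by_pointer
    {
        by_pointer() : sub(new big_payload()) {}
        std::unique_ptr<big_payload> sub;
    };

    int main()
    {
        std::cout << "by value:   " << sizeof(by_value) << " bytes\n";
        std::cout << "by pointer: " << sizeof(by_pointer) << " bytes\n";
    }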
@@ -535,16 +535,28 @@ namespace dlib
         add_layer(
         ):
+            subnetwork(new subnet_type()),
             this_layer_setup_called(false),
             gradient_input_is_stale(true),
             get_output_and_gradient_input_disabled(false)
         {
             if (this_layer_operates_inplace())
-                subnetwork.disable_output_and_gradient_getters();
+                subnetwork->disable_output_and_gradient_getters();
         }
 
-        add_layer(const add_layer&) = default;
-        add_layer& operator=(const add_layer&) = default;
+        add_layer(const add_layer& item)
+        {
+            details = item.details;
+            subnetwork.reset(new subnet_type(*item.subnetwork));
+            this_layer_setup_called = item.this_layer_setup_called;
+            gradient_input_is_stale = item.gradient_input_is_stale;
+            get_output_and_gradient_input_disabled = item.get_output_and_gradient_input_disabled;
+            x_grad = item.x_grad;
+            cached_output = item.cached_output;
+            params_grad = item.params_grad;
+            temp_tensor = item.temp_tensor;
+        }
+        add_layer& operator=(const add_layer& item) { add_layer(item).swap(*this); return *this;}
+
         add_layer(add_layer&& item) : add_layer() { swap(item); }
         add_layer& operator=(add_layer&& item) { swap(item); return *this; }
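Because std::unique_ptr is move-only, the previously defaulted copy operations no longer compile; the hunk above replaces them with a copy constructor that deep-copies the pointee and an assignment operator written in the copy-and-swap style. A generic sketch of that idiom (the widget type here is hypothetical, just to show the shape):

    #include <memory>

    // Hypothetical class holding a heap-allocated member, copied the same
    // way the patch copies add_layer.
    class widget
    {
    public:
        widget() : data(new int(0)) {}

        // Deep-copy: allocate a fresh object rather than sharing the pointer.
        widget(const widget& item) : data(new int(*item.data)) {}

        // Copy-and-swap: build a temporary copy, then swap it into *this.
        // The copy does all the work that can throw, so *this is never
        // left half-assigned.
        widget& operator=(const widget& item)
        {
            widget(item).swap(*this);
            return *this;
        }

        void swap(widget& item) { data.swap(item.data); }

    private:
        std::unique_ptr<int> data;
    };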
@@ -563,7 +575,7 @@ namespace dlib
         add_layer(
             const add_layer<T,U,E>& item
         ) :
-            subnetwork(item.subnet()),
+            subnetwork(new subnet_type(item.subnet())),
             details(item.layer_details()),
             this_layer_setup_called(item.this_layer_setup_called),
             gradient_input_is_stale(item.gradient_input_is_stale),
@@ -572,7 +584,7 @@ namespace dlib
             cached_output(item.cached_output)
         {
             if (this_layer_operates_inplace())
-                subnetwork.disable_output_and_gradient_getters();
+                subnetwork->disable_output_and_gradient_getters();
         }
 
         template <typename ...T>
@@ -581,13 +593,13 @@ namespace dlib
             T&& ...args
         ) :
             details(layer_det),
-            subnetwork(std::forward<T>(args)...),
+            subnetwork(new subnet_type(std::forward<T>(args)...)),
             this_layer_setup_called(false),
             gradient_input_is_stale(true),
             get_output_and_gradient_input_disabled(false)
         {
             if (this_layer_operates_inplace())
-                subnetwork.disable_output_and_gradient_getters();
+                subnetwork->disable_output_and_gradient_getters();
         }
 
         template <typename ...T>
@@ -596,13 +608,13 @@ namespace dlib
             T&& ...args
         ) :
             details(std::move(layer_det)),
-            subnetwork(std::forward<T>(args)...),
+            subnetwork(new subnet_type(std::forward<T>(args)...)),
             this_layer_setup_called(false),
             gradient_input_is_stale(true),
             get_output_and_gradient_input_disabled(false)
         {
             if (this_layer_operates_inplace())
-                subnetwork.disable_output_and_gradient_getters();
+                subnetwork->disable_output_and_gradient_getters();
         }
 
         template <typename ...T, typename LD, typename ...U>
@@ -611,13 +623,13 @@ namespace dlib
             T&& ...args
         ) :
             details(tuple_head(layer_det)),
-            subnetwork(tuple_tail(layer_det),std::forward<T>(args)...),
+            subnetwork(new subnet_type(tuple_tail(layer_det),std::forward<T>(args)...)),
             this_layer_setup_called(false),
            gradient_input_is_stale(true),
             get_output_and_gradient_input_disabled(false)
         {
             if (this_layer_operates_inplace())
-                subnetwork.disable_output_and_gradient_getters();
+                subnetwork->disable_output_and_gradient_getters();
         }
 
         template <typename ...T, typename LD, typename ...U>
@@ -641,7 +653,7 @@ namespace dlib
             resizable_tensor& data
         ) const
         {
-            subnetwork.to_tensor(ibegin,iend,data);
+            subnetwork->to_tensor(ibegin,iend,data);
         }
 
         template <typename input_iterator>
@@ -662,8 +674,8 @@ namespace dlib
         const tensor& forward(const tensor& x)
         {
-            subnetwork.forward(x);
-            const dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
+            subnetwork->forward(x);
+            const dimpl::subnet_wrapper<subnet_type> wsub(*subnetwork);
             if (!this_layer_setup_called)
             {
                 details.setup(wsub);
@@ -682,7 +694,7 @@ namespace dlib
         tensor& private_get_output() const
         {
             if (const_cast<add_layer&>(*this).this_layer_operates_inplace())
-                return subnetwork.private_get_output();
+                return subnetwork->private_get_output();
             else
                 return const_cast<resizable_tensor&>(cached_output);
         }
@@ -690,7 +702,7 @@ namespace dlib
         {
             if (this_layer_operates_inplace())
             {
-                return subnetwork.private_get_gradient_input();
+                return subnetwork->private_get_gradient_input();
             }
             else
             {
@@ -722,19 +734,19 @@ namespace dlib
         template <typename solver_type>
         void update(const tensor& x, sstack<solver_type,num_layers>& solvers)
         {
-            dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
+            dimpl::subnet_wrapper<subnet_type> wsub(*subnetwork);
             params_grad.copy_size(details.get_layer_params());
             impl::call_layer_backward(details, private_get_output(),
                 private_get_gradient_input(), wsub, static_cast<tensor&>(params_grad));
             // Don't try to adjust the parameters if this layer doesn't have any.
             if (params_grad.size() != 0)
                 solvers.top()(details, static_cast<const tensor&>(params_grad));
-            subnetwork.update(x, solvers.pop());
+            subnetwork->update(x, solvers.pop());
             gradient_input_is_stale = true;
         }
 
-        const subnet_type& subnet() const { return subnetwork; }
-        subnet_type& subnet() { return subnetwork; }
+        const subnet_type& subnet() const { return *subnetwork; }
+        subnet_type& subnet() { return *subnetwork; }
         const layer_details_type& layer_details() const { return details; }
         layer_details_type& layer_details() { return details; }
@@ -746,14 +758,14 @@ namespace dlib
             params_grad.clear();
             temp_tensor.clear();
             gradient_input_is_stale = true;
-            subnetwork.clean();
+            subnetwork->clean();
        }
 
         friend void serialize(const add_layer& item, std::ostream& out)
         {
             int version = 1;
             serialize(version, out);
-            serialize(item.subnetwork, out);
+            serialize(*item.subnetwork, out);
             serialize(item.details, out);
             serialize(item.this_layer_setup_called, out);
             serialize(item.gradient_input_is_stale, out);
@@ -768,7 +780,7 @@ namespace dlib
             deserialize(version, in);
             if (version != 1)
                 throw serialization_error("Unexpected version found while deserializing dlib::add_layer.");
-            deserialize(item.subnetwork, in);
+            deserialize(*item.subnetwork, in);
             deserialize(item.details, in);
             deserialize(item.this_layer_setup_called, in);
             deserialize(item.gradient_input_is_stale, in);
@@ -785,12 +797,12 @@ namespace dlib
             // This layer can run in-place if it's an in-place capable layer and also if
             // the layer it's on top of doesn't need its own output tensor (since in-place
             // layers overwrite that tensor)
-            return impl::is_inplace_layer(details, subnetwork) && !subnetwork.this_layer_requires_forward_output();
+            return impl::is_inplace_layer(details, *subnetwork) && !subnetwork->this_layer_requires_forward_output();
         }
 
         bool this_layer_requires_forward_output(
         )
         {
-            return impl::backward_requires_forward_output(details, subnetwork);
+            return impl::backward_requires_forward_output(details, *subnetwork);
         }
 
         void swap(add_layer& item)
@@ -806,7 +818,7 @@ namespace dlib
         LAYER_DETAILS details;
-        subnet_type subnetwork;
+        std::unique_ptr<subnet_type> subnetwork;
         bool this_layer_setup_called;
         bool gradient_input_is_stale;
         bool get_output_and_gradient_input_disabled;