Commit 3597df5e authored by Davis King

Made add_layer hold subnetworks through a pointer so that most of a
network is allocated on the heap rather than resulting in really large
stack usage for large networks.
parent 72b250bb
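A dlib network type is a recursively nested stack of add_layer templates. When each add_layer stores its subnetwork by value, the whole network is one object whose size grows with its depth, so declaring a deep network as a local variable can consume a large amount of stack space. Holding the subnetwork through a std::unique_ptr keeps each level small and moves nearly all of the storage onto the heap. The following is a minimal toy sketch of that effect, not dlib's actual types; the layer state and buffer size are invented for illustration:

```cpp
#include <iostream>
#include <memory>

// Toy stand-in for a layer's parameters/working buffers.  Real dlib layers hold
// resizable_tensor objects; a fixed array is enough to show how the size grows.
struct toy_layer_state { double buf[1024]; };   // roughly 8 KB per layer

// By-value nesting: every level's state lives inside the outermost object, so a
// deep network declared as a local variable is one very large stack allocation.
template <typename SUBNET>
struct by_value_layer { toy_layer_state state; SUBNET subnetwork; };

// Pointer nesting (what this commit switches add_layer to): each level stores
// only a unique_ptr, so nearly all of the network ends up on the heap.
template <typename SUBNET>
struct by_pointer_layer
{
    toy_layer_state state;
    std::unique_ptr<SUBNET> subnetwork{new SUBNET()};
};

struct input_layer { toy_layer_state state; };

int main()
{
    using deep_by_value   = by_value_layer<by_value_layer<by_value_layer<input_layer>>>;
    using deep_by_pointer = by_pointer_layer<by_pointer_layer<by_pointer_layer<input_layer>>>;

    // sizeof grows linearly with depth in the by-value case, but stays at about
    // one layer's worth of state plus a pointer in the by-pointer case.
    std::cout << sizeof(deep_by_value) << " vs " << sizeof(deep_by_pointer) << "\n";
}
```

The trade-off, visible in the diff below, is that copying a network must now deep-copy through the pointer instead of relying on the compiler-generated copy operations.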
@@ -535,16 +535,28 @@ namespace dlib
         add_layer(
         ):
+            subnetwork(new subnet_type()),
             this_layer_setup_called(false),
             gradient_input_is_stale(true),
             get_output_and_gradient_input_disabled(false)
         {
             if (this_layer_operates_inplace())
-                subnetwork.disable_output_and_gradient_getters();
+                subnetwork->disable_output_and_gradient_getters();
         }
 
-        add_layer(const add_layer&) = default;
-        add_layer& operator=(const add_layer&) = default;
+        add_layer(const add_layer& item)
+        {
+            details = item.details;
+            subnetwork.reset(new subnet_type(*item.subnetwork));
+            this_layer_setup_called = item.this_layer_setup_called;
+            gradient_input_is_stale = item.gradient_input_is_stale;
+            get_output_and_gradient_input_disabled = item.get_output_and_gradient_input_disabled;
+            x_grad = item.x_grad;
+            cached_output = item.cached_output;
+            params_grad = item.params_grad;
+            temp_tensor = item.temp_tensor;
+        }
+        add_layer& operator=(const add_layer& item) { add_layer(item).swap(*this); return *this;}
 
         add_layer(add_layer&& item) : add_layer() { swap(item); }
         add_layer& operator=(add_layer&& item) { swap(item); return *this; }
@@ -563,7 +575,7 @@ namespace dlib
         add_layer(
             const add_layer<T,U,E>& item
         ) :
-            subnetwork(item.subnet()),
+            subnetwork(new subnet_type(item.subnet())),
             details(item.layer_details()),
             this_layer_setup_called(item.this_layer_setup_called),
             gradient_input_is_stale(item.gradient_input_is_stale),
@@ -572,7 +584,7 @@ namespace dlib
             cached_output(item.cached_output)
         {
             if (this_layer_operates_inplace())
-                subnetwork.disable_output_and_gradient_getters();
+                subnetwork->disable_output_and_gradient_getters();
         }
 
         template <typename ...T>
@@ -581,13 +593,13 @@ namespace dlib
             T&& ...args
         ) :
             details(layer_det),
-            subnetwork(std::forward<T>(args)...),
+            subnetwork(new subnet_type(std::forward<T>(args)...)),
             this_layer_setup_called(false),
             gradient_input_is_stale(true),
             get_output_and_gradient_input_disabled(false)
         {
             if (this_layer_operates_inplace())
-                subnetwork.disable_output_and_gradient_getters();
+                subnetwork->disable_output_and_gradient_getters();
         }
 
         template <typename ...T>
@@ -596,13 +608,13 @@ namespace dlib
             T&& ...args
         ) :
             details(std::move(layer_det)),
-            subnetwork(std::forward<T>(args)...),
+            subnetwork(new subnet_type(std::forward<T>(args)...)),
            this_layer_setup_called(false),
             gradient_input_is_stale(true),
             get_output_and_gradient_input_disabled(false)
         {
             if (this_layer_operates_inplace())
-                subnetwork.disable_output_and_gradient_getters();
+                subnetwork->disable_output_and_gradient_getters();
         }
 
         template <typename ...T, typename LD, typename ...U>
@@ -611,13 +623,13 @@ namespace dlib
             T&& ...args
         ) :
             details(tuple_head(layer_det)),
-            subnetwork(tuple_tail(layer_det),std::forward<T>(args)...),
+            subnetwork(new subnet_type(tuple_tail(layer_det),std::forward<T>(args)...)),
             this_layer_setup_called(false),
             gradient_input_is_stale(true),
             get_output_and_gradient_input_disabled(false)
         {
             if (this_layer_operates_inplace())
-                subnetwork.disable_output_and_gradient_getters();
+                subnetwork->disable_output_and_gradient_getters();
         }
 
         template <typename ...T, typename LD, typename ...U>
@@ -641,7 +653,7 @@ namespace dlib
             resizable_tensor& data
         ) const
         {
-            subnetwork.to_tensor(ibegin,iend,data);
+            subnetwork->to_tensor(ibegin,iend,data);
         }
 
         template <typename input_iterator>
@@ -662,8 +674,8 @@ namespace dlib
         const tensor& forward(const tensor& x)
         {
-            subnetwork.forward(x);
-            const dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
+            subnetwork->forward(x);
+            const dimpl::subnet_wrapper<subnet_type> wsub(*subnetwork);
             if (!this_layer_setup_called)
             {
                 details.setup(wsub);
@@ -682,7 +694,7 @@ namespace dlib
         tensor& private_get_output() const
         {
             if (const_cast<add_layer&>(*this).this_layer_operates_inplace())
-                return subnetwork.private_get_output();
+                return subnetwork->private_get_output();
             else
                 return const_cast<resizable_tensor&>(cached_output);
         }
@@ -690,7 +702,7 @@ namespace dlib
         {
             if (this_layer_operates_inplace())
             {
-                return subnetwork.private_get_gradient_input();
+                return subnetwork->private_get_gradient_input();
             }
             else
             {
@@ -722,19 +734,19 @@ namespace dlib
         template <typename solver_type>
         void update(const tensor& x, sstack<solver_type,num_layers>& solvers)
         {
-            dimpl::subnet_wrapper<subnet_type> wsub(subnetwork);
+            dimpl::subnet_wrapper<subnet_type> wsub(*subnetwork);
             params_grad.copy_size(details.get_layer_params());
             impl::call_layer_backward(details, private_get_output(),
                 private_get_gradient_input(), wsub, static_cast<tensor&>(params_grad));
             // Don't try to adjust the parameters if this layer doesn't have any.
             if (params_grad.size() != 0)
                 solvers.top()(details, static_cast<const tensor&>(params_grad));
-            subnetwork.update(x, solvers.pop());
+            subnetwork->update(x, solvers.pop());
             gradient_input_is_stale = true;
         }
 
-        const subnet_type& subnet() const { return subnetwork; }
-        subnet_type& subnet() { return subnetwork; }
+        const subnet_type& subnet() const { return *subnetwork; }
+        subnet_type& subnet() { return *subnetwork; }
 
         const layer_details_type& layer_details() const { return details; }
         layer_details_type& layer_details() { return details; }
@@ -746,14 +758,14 @@ namespace dlib
             params_grad.clear();
             temp_tensor.clear();
             gradient_input_is_stale = true;
-            subnetwork.clean();
+            subnetwork->clean();
         }
 
         friend void serialize(const add_layer& item, std::ostream& out)
         {
             int version = 1;
             serialize(version, out);
-            serialize(item.subnetwork, out);
+            serialize(*item.subnetwork, out);
             serialize(item.details, out);
             serialize(item.this_layer_setup_called, out);
             serialize(item.gradient_input_is_stale, out);
@@ -768,7 +780,7 @@ namespace dlib
             deserialize(version, in);
             if (version != 1)
                 throw serialization_error("Unexpected version found while deserializing dlib::add_layer.");
-            deserialize(item.subnetwork, in);
+            deserialize(*item.subnetwork, in);
             deserialize(item.details, in);
             deserialize(item.this_layer_setup_called, in);
             deserialize(item.gradient_input_is_stale, in);
@@ -785,12 +797,12 @@ namespace dlib
             // This layer can run in-place if it's an in-place capable layer and also if
             // the layer it's on top of doesn't need its own output tensor (since in-place
             // layers overwrite that tensor)
-            return impl::is_inplace_layer(details, subnetwork) && !subnetwork.this_layer_requires_forward_output();
+            return impl::is_inplace_layer(details, *subnetwork) && !subnetwork->this_layer_requires_forward_output();
         }
 
         bool this_layer_requires_forward_output(
         )
         {
-            return impl::backward_requires_forward_output(details, subnetwork);
+            return impl::backward_requires_forward_output(details, *subnetwork);
         }
 
         void swap(add_layer& item)
@@ -806,7 +818,7 @@ namespace dlib
         LAYER_DETAILS details;
-        subnet_type subnetwork;
+        std::unique_ptr<subnet_type> subnetwork;
         bool this_layer_setup_called;
         bool gradient_input_is_stale;
         bool get_output_and_gradient_input_disabled;
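Because the subnetwork is now held through a unique_ptr, the previously defaulted copy operations could no longer be used (they would not compile for a move-only member), so the commit adds a copy constructor that deep-copies the pointed-to subnetwork and an assignment operator written in the copy-and-swap style, matching the existing move operations. A small hedged sketch of that idiom on a hypothetical `widget` class (not dlib code) looks like this:

```cpp
#include <memory>
#include <utility>

// Hypothetical example type; the pattern mirrors what add_layer now does with
// its std::unique_ptr<subnet_type> member.
class widget
{
public:
    widget() : impl(new int(0)) {}

    // Deep copy: clone what the pointer refers to, not the pointer itself.
    widget(const widget& item) : impl(new int(*item.impl)) {}

    // Copy-and-swap assignment: build a temporary copy, then swap state with it;
    // the temporary's destructor releases this object's old resources.
    widget& operator=(const widget& item) { widget(item).swap(*this); return *this; }

    // Move operations simply steal the pointer via swap, as add_layer does.
    widget(widget&& item) : widget() { swap(item); }
    widget& operator=(widget&& item) { swap(item); return *this; }

    void swap(widget& item) { std::swap(impl, item.impl); }

private:
    std::unique_ptr<int> impl;
};

int main()
{
    widget a, b;
    b = a;                   // deep copy via copy-and-swap
    widget c(std::move(a));  // move steals a's pointer
}
```

Note that serialize and deserialize simply dereference the pointer (`*item.subnetwork`), so the on-disk format of a saved network is unaffected by this change.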