Commit df19361c authored by Davis King's avatar Davis King

Made calling clean() on network objects also call clean on any layer detail

objects that also provide a .clean() method.
parent 115e8b6d
...@@ -56,6 +56,26 @@ namespace dlib ...@@ -56,6 +56,26 @@ namespace dlib
template <typename T> template <typename T>
double get_weight_decay_multiplier(const T& obj) { return impl::get_weight_decay_multiplier(obj, special_()); } double get_weight_decay_multiplier(const T& obj) { return impl::get_weight_decay_multiplier(obj, special_()); }
// ----------------------------------------------------------------------------------------
namespace impl
{
template <typename T, typename int_<decltype(&T::clean)>::type = 0>
void call_clean_method_if_exists (
T& obj,
special_
) { obj.clean(); }
template <typename T>
void call_clean_method_if_exists (T& , general_) {}
}
template <typename T>
void call_clean_method_if_exists(T& obj) { impl::call_clean_method_if_exists(obj, special_()); }
/*!
ensures
- calls obj.clean() if obj has a .clean() method.
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
namespace impl namespace impl
...@@ -893,6 +913,7 @@ namespace dlib ...@@ -893,6 +913,7 @@ namespace dlib
temp_tensor.clear(); temp_tensor.clear();
gradient_input_is_stale = true; gradient_input_is_stale = true;
subnetwork->clean(); subnetwork->clean();
call_clean_method_if_exists(details);
} }
friend void serialize(const add_layer& item, std::ostream& out) friend void serialize(const add_layer& item, std::ostream& out)
...@@ -1255,6 +1276,7 @@ namespace dlib ...@@ -1255,6 +1276,7 @@ namespace dlib
params_grad.clear(); params_grad.clear();
temp_tensor.clear(); temp_tensor.clear();
gradient_input_is_stale = true; gradient_input_is_stale = true;
call_clean_method_if_exists(details);
} }
friend void serialize(const add_layer& item, std::ostream& out) friend void serialize(const add_layer& item, std::ostream& out)
......
...@@ -562,6 +562,8 @@ namespace dlib ...@@ -562,6 +562,8 @@ namespace dlib
clean(). The purpose of clean() is to compact the network object prior clean(). The purpose of clean() is to compact the network object prior
to saving it to disk so that it takes up less space and the IO is to saving it to disk so that it takes up less space and the IO is
quicker. quicker.
- This also calls the .clean() method on any layer details objects that
define a .clean() method.
!*/ !*/
}; };
......
...@@ -1315,6 +1315,12 @@ namespace dlib ...@@ -1315,6 +1315,12 @@ namespace dlib
deserialize(item.mask, in); deserialize(item.mask, in);
} }
void clean(
)
{
mask.clear();
}
friend std::ostream& operator<<(std::ostream& out, const dropout_& item) friend std::ostream& operator<<(std::ostream& out, const dropout_& item)
{ {
out << "dropout\t (" out << "dropout\t ("
......
...@@ -358,6 +358,20 @@ namespace dlib ...@@ -358,6 +358,20 @@ namespace dlib
input_tensor_to_output_tensor(). input_tensor_to_output_tensor().
!*/ !*/
void clean (
);
/*!
Implementing this function is optional. If you don't need it then you don't
have to provide a clean(). But if you do provide it then it must behave as
follows:
ensures
- calling clean() Causes this object to forget about everything except its
parameters. This is useful if your layer caches information between
forward and backward passes and you want to clean out that cache
information before saving the network to disk.
!*/
}; };
std::ostream& operator<<(std::ostream& out, const EXAMPLE_COMPUTATIONAL_LAYER_& item); std::ostream& operator<<(std::ostream& out, const EXAMPLE_COMPUTATIONAL_LAYER_& item);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment