Commit 0b235fe5 authored by Davis King

Added the repeat layer and generally optimized the code for really deep networks.

This revolved mostly around removing really deep template recursions, since those upset the compiler when you build really deep networks.
parent 7991275e
...@@ -368,68 +368,50 @@ namespace dlib ...@@ -368,68 +368,50 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename T, size_t N> template <typename T>
class sstack class sstack
{ {
public: public:
static_assert(N > 0, "You can't create an empty sstack.");
typedef T value_type; typedef T value_type;
const static size_t num_elements = N;
sstack() {}
sstack(const T& item_) : item(item_), data(item_) {}
const T& top() const { return item; } sstack() = delete;
T& top() { return item; }
size_t size() const { return N; } sstack (
T* data_,
size_t s
) : data(data_), mysize(s) {}
const sstack<T,N-1>& pop() const { return data; } const T& top() const
sstack<T,N-1>& pop() { return data; }
friend void serialize(const sstack& item, std::ostream& out)
{ {
serialize(item.top(), out); DLIB_CASSERT(size() != 0, "You can't call top() on an empty stack");
serialize(item.pop(), out); return *data;
} }
T& top()
friend void deserialize(sstack& item, std::istream& in)
{ {
deserialize(item.top(), in); DLIB_CASSERT(size() != 0, "You can't call top() on an empty stack");
deserialize(item.pop(), in); return *data;
} }
private: size_t size() const { return mysize; }
T item;
sstack<T,N-1> data;
};
template <typename T> sstack pop(size_t num=1)
class sstack<T,1> // base case of recursive definition.
{ {
public: DLIB_CASSERT(num < size(), "You can't pop more things from the stack than it has in it.");
sstack() {} return sstack(data+num, mysize-num);
sstack(const T& item_) : item(item_) {} }
const T& top() const { return item; }
T& top() { return item; }
size_t size() const { return 1; } private:
friend void serialize(const sstack& item, std::ostream& out) T* data;
{ size_t mysize;
serialize(item.top(), out); };
}
friend void deserialize(sstack& item, std::istream& in) template <typename T>
sstack<T> make_sstack(std::vector<T>& item)
{ {
deserialize(item.top(), in); return sstack<T>(item.data(), item.size());
} }
private:
T item;
};
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -568,6 +550,8 @@ namespace dlib ...@@ -568,6 +550,8 @@ namespace dlib
friend class add_tag_layer; friend class add_tag_layer;
template <template<typename> class T, typename U> template <template<typename> class T, typename U>
friend class add_skip_layer; friend class add_skip_layer;
template <size_t N, template<typename> class L, typename S>
friend class repeat;
// Allow copying networks from one to another as long as their corresponding // Allow copying networks from one to another as long as their corresponding
// layers can be constructed from each other. // layers can be constructed from each other.
...@@ -732,17 +716,18 @@ namespace dlib ...@@ -732,17 +716,18 @@ namespace dlib
} }
const tensor& get_final_data_gradient( const tensor& get_final_data_gradient(
) const { return subnetwork.get_final_data_gradient(); } ) const { return subnetwork->get_final_data_gradient(); }
template <typename solver_type> template <typename solver_type>
void update(const tensor& x, sstack<solver_type,num_layers>& solvers) void update(const tensor& x, sstack<solver_type> solvers)
{ {
update(x,private_get_gradient_input(),solvers); update(x,private_get_gradient_input(),solvers);
} }
template <typename solver_type> template <typename solver_type>
void update(const tensor& x, const tensor& gradient_input, sstack<solver_type,num_layers>& solvers) void update(const tensor& x, const tensor& gradient_input, sstack<solver_type> solvers)
{ {
DLIB_CASSERT(solvers.size()>=num_layers,"");
dimpl::subnet_wrapper<subnet_type> wsub(*subnetwork); dimpl::subnet_wrapper<subnet_type> wsub(*subnetwork);
params_grad.copy_size(details.get_layer_params()); params_grad.copy_size(details.get_layer_params());
impl::call_layer_backward(details, private_get_output(), impl::call_layer_backward(details, private_get_output(),
...@@ -881,6 +866,8 @@ namespace dlib ...@@ -881,6 +866,8 @@ namespace dlib
friend class add_tag_layer; friend class add_tag_layer;
template <template<typename> class T, typename U> template <template<typename> class T, typename U>
friend class add_skip_layer; friend class add_skip_layer;
template <size_t N, template<typename> class L, typename S>
friend class repeat;
// Allow copying networks from one to another as long as their corresponding // Allow copying networks from one to another as long as their corresponding
// layers can be constructed from each other. // layers can be constructed from each other.
...@@ -894,7 +881,8 @@ namespace dlib ...@@ -894,7 +881,8 @@ namespace dlib
gradient_input_is_stale(item.gradient_input_is_stale), gradient_input_is_stale(item.gradient_input_is_stale),
get_output_and_gradient_input_disabled(false), get_output_and_gradient_input_disabled(false),
x_grad(item.x_grad), x_grad(item.x_grad),
cached_output(item.cached_output) cached_output(item.cached_output),
grad_final(item.grad_final)
{ {
} }
...@@ -985,7 +973,7 @@ namespace dlib ...@@ -985,7 +973,7 @@ namespace dlib
const tensor& forward (const tensor& x) const tensor& forward (const tensor& x)
{ {
DLIB_CASSERT(x.num_samples()%sample_expansion_factor == 0,""); DLIB_CASSERT(x.num_samples()%sample_expansion_factor == 0,"");
subnet_wrapper wsub(x, grad_final_ignored); subnet_wrapper wsub(x, grad_final);
if (!this_layer_setup_called) if (!this_layer_setup_called)
{ {
details.setup(wsub); details.setup(wsub);
...@@ -1025,18 +1013,24 @@ namespace dlib ...@@ -1025,18 +1013,24 @@ namespace dlib
} }
const tensor& get_final_data_gradient( const tensor& get_final_data_gradient(
) const { return grad_final_ignored; } ) const { return grad_final; }
template <typename solver_type> template <typename solver_type>
void update(const tensor& x, sstack<solver_type,num_layers>& solvers) void update(const tensor& x, sstack<solver_type> solvers)
{ {
update(x,private_get_gradient_input(),solvers); update(x,private_get_gradient_input(),solvers);
} }
template <typename solver_type> template <typename solver_type>
void update(const tensor& x, const tensor& gradient_input, sstack<solver_type,num_layers>& solvers) void update(const tensor& x, const tensor& gradient_input, sstack<solver_type> solvers)
{ {
subnet_wrapper wsub(x, grad_final_ignored); DLIB_CASSERT(solvers.size()>=num_layers,"");
// make sure grad_final is initialized to 0
if (!have_same_dimensions(x, grad_final))
grad_final.copy_size(x);
grad_final = 0;
subnet_wrapper wsub(x, grad_final);
params_grad.copy_size(details.get_layer_params()); params_grad.copy_size(details.get_layer_params());
impl::call_layer_backward(details, private_get_output(), impl::call_layer_backward(details, private_get_output(),
gradient_input, wsub, static_cast<tensor&>(params_grad)); gradient_input, wsub, static_cast<tensor&>(params_grad));
...@@ -1055,7 +1049,7 @@ namespace dlib ...@@ -1055,7 +1049,7 @@ namespace dlib
void clean() void clean()
{ {
x_grad.clear(); x_grad.clear();
grad_final_ignored.clear(); grad_final.clear();
cached_output.clear(); cached_output.clear();
params_grad.clear(); params_grad.clear();
temp_tensor.clear(); temp_tensor.clear();
...@@ -1064,7 +1058,7 @@ namespace dlib ...@@ -1064,7 +1058,7 @@ namespace dlib
friend void serialize(const add_layer& item, std::ostream& out) friend void serialize(const add_layer& item, std::ostream& out)
{ {
int version = 1; int version = 2;
serialize(version, out); serialize(version, out);
serialize(item.input_layer, out); serialize(item.input_layer, out);
serialize(item.details, out); serialize(item.details, out);
...@@ -1073,13 +1067,14 @@ namespace dlib ...@@ -1073,13 +1067,14 @@ namespace dlib
serialize(item.get_output_and_gradient_input_disabled, out); serialize(item.get_output_and_gradient_input_disabled, out);
serialize(item.x_grad, out); serialize(item.x_grad, out);
serialize(item.cached_output, out); serialize(item.cached_output, out);
serialize(item.grad_final, out);
} }
friend void deserialize(add_layer& item, std::istream& in) friend void deserialize(add_layer& item, std::istream& in)
{ {
int version = 0; int version = 0;
deserialize(version, in); deserialize(version, in);
if (version != 1) if (version != 2)
throw serialization_error("Unexpected version found while deserializing dlib::add_layer."); throw serialization_error("Unexpected version found while deserializing dlib::add_layer.");
deserialize(item.input_layer, in); deserialize(item.input_layer, in);
deserialize(item.details, in); deserialize(item.details, in);
...@@ -1088,6 +1083,7 @@ namespace dlib ...@@ -1088,6 +1083,7 @@ namespace dlib
deserialize(item.get_output_and_gradient_input_disabled, in); deserialize(item.get_output_and_gradient_input_disabled, in);
deserialize(item.x_grad, in); deserialize(item.x_grad, in);
deserialize(item.cached_output, in); deserialize(item.cached_output, in);
deserialize(item.grad_final, in);
} }
private: private:
...@@ -1095,15 +1091,15 @@ namespace dlib ...@@ -1095,15 +1091,15 @@ namespace dlib
bool this_layer_requires_forward_output( bool this_layer_requires_forward_output(
) )
{ {
subnet_wrapper wsub(grad_final_ignored, grad_final_ignored); subnet_wrapper wsub(grad_final, grad_final);
return impl::backward_requires_forward_output(details, wsub); return impl::backward_requires_forward_output(details, wsub);
} }
class subnet_wrapper class subnet_wrapper
{ {
public: public:
subnet_wrapper(const tensor& x_, resizable_tensor& grad_final_ignored_) : subnet_wrapper(const tensor& x_, resizable_tensor& grad_final_) :
x(x_), grad_final_ignored(grad_final_ignored_) {} x(x_), grad_final(grad_final_) {}
subnet_wrapper(const subnet_wrapper&) = delete; subnet_wrapper(const subnet_wrapper&) = delete;
subnet_wrapper& operator=(const subnet_wrapper&) = delete; subnet_wrapper& operator=(const subnet_wrapper&) = delete;
...@@ -1111,21 +1107,17 @@ namespace dlib ...@@ -1111,21 +1107,17 @@ namespace dlib
const tensor& get_output() const { return x; } const tensor& get_output() const { return x; }
tensor& get_gradient_input() tensor& get_gradient_input()
{ {
// It doesn't matter what values are in this tensor but client code will if (!have_same_dimensions(x, grad_final))
// always assume it's the same dimension as the output so make sure that is
// the case. Note that we do set it to a non-crazy value though to avoid
// it being full of NaN and slowing the processing down.
if (!have_same_dimensions(x, grad_final_ignored))
{ {
grad_final_ignored.copy_size(x); grad_final.copy_size(x);
grad_final_ignored = 0; grad_final = 0;
} }
return grad_final_ignored; return grad_final;
} }
private: private:
const tensor& x; const tensor& x;
resizable_tensor& grad_final_ignored; resizable_tensor& grad_final;
}; };
void swap(add_layer& item) void swap(add_layer& item)
...@@ -1137,6 +1129,7 @@ namespace dlib ...@@ -1137,6 +1129,7 @@ namespace dlib
std::swap(get_output_and_gradient_input_disabled, item.get_output_and_gradient_input_disabled); std::swap(get_output_and_gradient_input_disabled, item.get_output_and_gradient_input_disabled);
std::swap(x_grad, item.x_grad); std::swap(x_grad, item.x_grad);
std::swap(cached_output, item.cached_output); std::swap(cached_output, item.cached_output);
std::swap(grad_final, item.grad_final);
} }
subnet_type input_layer; subnet_type input_layer;
...@@ -1146,13 +1139,13 @@ namespace dlib ...@@ -1146,13 +1139,13 @@ namespace dlib
bool get_output_and_gradient_input_disabled; bool get_output_and_gradient_input_disabled;
resizable_tensor x_grad; resizable_tensor x_grad;
resizable_tensor cached_output; resizable_tensor cached_output;
resizable_tensor grad_final;
// The following 3 objects don't logically contribute to the state of this class. // The following 2 objects don't logically contribute to the state of this class.
// They are only here to prevent them from being reallocated over and over in // They are only here to prevent them from being reallocated over and over in
// member functions. // member functions.
resizable_tensor params_grad; resizable_tensor params_grad;
resizable_tensor temp_tensor; resizable_tensor temp_tensor;
resizable_tensor grad_final_ignored;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -1167,7 +1160,7 @@ namespace dlib ...@@ -1167,7 +1160,7 @@ namespace dlib
public: public:
typedef SUBNET subnet_type; typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type; typedef typename subnet_type::input_type input_type;
const static size_t num_layers = subnet_type::num_layers + 1; const static size_t num_layers = subnet_type::num_layers;
const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor; const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor;
static_assert(sample_expansion_factor >= 1, static_assert(sample_expansion_factor >= 1,
"The input layer can't produce fewer output tensors than there are inputs."); "The input layer can't produce fewer output tensors than there are inputs.");
...@@ -1232,15 +1225,15 @@ namespace dlib ...@@ -1232,15 +1225,15 @@ namespace dlib
) const { return subnetwork.get_final_data_gradient(); } ) const { return subnetwork.get_final_data_gradient(); }
template <typename solver_type> template <typename solver_type>
void update(const tensor& x, sstack<solver_type,num_layers>& solvers) void update(const tensor& x, sstack<solver_type> solvers)
{ {
subnetwork.update(x,solvers.pop()); subnetwork.update(x,solvers);
} }
template <typename solver_type> template <typename solver_type>
void update(const tensor& x, const tensor& gradient_input, sstack<solver_type,num_layers>& solvers) void update(const tensor& x, const tensor& gradient_input, sstack<solver_type> solvers)
{ {
subnetwork.update(x,gradient_input,solvers.pop()); subnetwork.update(x,gradient_input,solvers);
} }
const subnet_type& subnet() const { return subnetwork; } const subnet_type& subnet() const { return subnetwork; }
...@@ -1277,6 +1270,8 @@ namespace dlib ...@@ -1277,6 +1270,8 @@ namespace dlib
friend class add_tag_layer; friend class add_tag_layer;
template <template<typename> class T, typename U> template <template<typename> class T, typename U>
friend class add_skip_layer; friend class add_skip_layer;
template <size_t N, template<typename> class L, typename S>
friend class repeat;
// You wouldn't put a tag on a layer if you didn't want to access its forward // You wouldn't put a tag on a layer if you didn't want to access its forward
// outputs. So this is always true. // outputs. So this is always true.
...@@ -1302,6 +1297,268 @@ namespace dlib ...@@ -1302,6 +1297,268 @@ namespace dlib
subnet_type subnetwork; subnet_type subnetwork;
}; };
// ----------------------------------------------------------------------------------------
namespace impl
{
class repeat_input_layer
{
/*!
None of the declarations in this object are really used. The only reason it
exists is to allow the repeat object to use a special input layer in its
internal networks which will cause add_tag_layer objects that happen to be
right at the input to not create copies of their input tensors. So
introducing the repeat_input_layer object allows us to optimize the
implementation of add_tag_layer for a special case that arises when it's
used in the context of the repeat layer.
!*/
public:
typedef int input_type;
const static unsigned int sample_expansion_factor = 1;
template <typename input_iterator>
void to_tensor (
input_iterator ,
input_iterator ,
resizable_tensor&
) const
{
DLIB_CASSERT(false,"This function should never be called");
}
friend void serialize(const repeat_input_layer&, std::ostream&){}
friend void deserialize(repeat_input_layer&, std::istream&){}
};
}
template <
size_t num,
template<typename> class LAYER,
typename SUBNET
>
class repeat
{
static_assert(num > 0, "You can't have a layer repeated 0 times.");
public:
typedef SUBNET subnet_type;
typedef typename SUBNET::input_type input_type;
const static size_t num_layers = (LAYER<SUBNET>::num_layers-SUBNET::num_layers)*num + SUBNET::num_layers;
const static unsigned int sample_expansion_factor = SUBNET::sample_expansion_factor;
typedef LAYER<impl::repeat_input_layer> repeated_layer_type;
repeat(
) :
details(num)
{
}
size_t num_repetitions (
) const { return num; }
const repeated_layer_type& get_repeated_layer (
size_t i
) const
{
DLIB_CASSERT(i < num_repetitions(), "");
return details[i];
}
repeated_layer_type& get_repeated_layer (
size_t i
)
{
DLIB_CASSERT(i < num_repetitions(), "");
return details[i];
}
repeat(const repeat&) = default;
repeat(repeat&&) = default;
repeat& operator=(repeat&&) = default;
repeat& operator=(const repeat&) = default;
template <template<typename> class T, typename U>
repeat(
const repeat<num,T,U>& item
) :
subnetwork(item.subnetwork)
{
for (auto&& d : item.details)
details.emplace_back(d);
}
template <typename T, typename ...U>
repeat(
T arg1,
U ...args2
):
details(num, std::move(arg1)),
subnetwork(std::move(args2)...)
{
}
template <typename T, typename ...U>
repeat(
std::tuple<>,
T arg1,
U ...args2
):
details(num, std::move(arg1)),
subnetwork(std::move(args2)...)
{
}
template <typename input_iterator>
void to_tensor (
input_iterator ibegin,
input_iterator iend,
resizable_tensor& data
) const
{
subnetwork.to_tensor(ibegin,iend,data);
}
template <typename input_iterator>
const tensor& operator() (
input_iterator ibegin,
input_iterator iend
)
{
to_tensor(ibegin,iend,temp_tensor);
return forward(temp_tensor);
}
const tensor& operator() (const input_type& x)
{
return (*this)(&x, &x+1);
}
const tensor& forward(const tensor& x)
{
subnetwork.forward(x);
details[details.size()-1].forward(subnetwork.get_output());
for (long i = details.size()-2; i >= 0; --i)
details[i].forward(details[i+1].get_output());
return private_get_output();
}
private:
tensor& private_get_output() const
{
return details[0].private_get_output();
}
tensor& private_get_gradient_input()
{
return details[0].private_get_gradient_input();
}
public:
const tensor& get_output() const
{
return details[0].get_output();
}
tensor& get_gradient_input()
{
return details[0].get_gradient_input();
}
template <typename solver_type>
void update(const tensor& x, sstack<solver_type> solvers)
{
update(x,private_get_gradient_input(),solvers);
}
template <typename solver_type>
void update(const tensor& x, const tensor& gradient_input, sstack<solver_type> solvers)
{
const auto cnt = (LAYER<SUBNET>::num_layers-SUBNET::num_layers);
if (details.size() > 1)
{
details[0].update(details[1].get_output(), gradient_input, solvers);
for (size_t i = 1; i < details.size(); ++i)
{
if (i+1 < details.size())
details[i].update(details[i+1].get_output(), details[i-1].get_final_data_gradient(), solvers.pop(cnt*i));
else
details[i].update(subnetwork.get_output(), details[i-1].get_final_data_gradient(), solvers.pop(cnt*i));
}
}
else
{
details[0].update(subnetwork.get_output(), gradient_input, solvers);
}
subnetwork.update(x, details.back().get_final_data_gradient(), solvers.pop(cnt*details.size()));
}
const subnet_type& subnet() const { return subnetwork; }
subnet_type& subnet() { return subnetwork; }
void clean()
{
temp_tensor.clear();
subnetwork.clean();
for (auto&& d : details)
d.clean();
}
friend void serialize(const repeat& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.details, out);
serialize(item.subnetwork, out);
}
friend void deserialize(repeat& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::repeat.");
deserialize(item.details, in);
deserialize(item.subnetwork, in);
}
private:
template <typename T, typename U, typename E>
friend class add_layer;
template <typename T, bool is_first, typename E>
friend class dimpl::subnet_wrapper;
template <unsigned long T, typename U, typename E>
friend class add_tag_layer;
template <template<typename> class T, typename U>
friend class add_skip_layer;
template <size_t N, template<typename> class L, typename S>
friend class repeat;
bool this_layer_requires_forward_output(
)
{
return details[0].this_layer_requires_forward_output();
}
void disable_output_and_gradient_getters (
)
{
details[0].disable_output_and_gradient_getters();
}
std::vector<repeated_layer_type> details;
subnet_type subnetwork;
// temp_tensor doesn't logically contribute to the state of this class.
// It is here only to avoid needing to reallocate it over and over.
resizable_tensor temp_tensor;
};
template <
size_t num,
template<typename> class LAYER,
typename SUBNET
>
struct is_nonloss_layer_type<repeat<num,LAYER,SUBNET>> : std::true_type {};
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// This version of add_tag_layer handles the special case where the subnetwork being given // This version of add_tag_layer handles the special case where the subnetwork being given
...@@ -1312,12 +1569,13 @@ namespace dlib ...@@ -1312,12 +1569,13 @@ namespace dlib
public: public:
typedef INPUT_LAYER subnet_type; typedef INPUT_LAYER subnet_type;
typedef typename subnet_type::input_type input_type; typedef typename subnet_type::input_type input_type;
const static size_t num_layers = 1; const static size_t num_layers = 0;
const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor; const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor;
static_assert(sample_expansion_factor >= 1, static_assert(sample_expansion_factor >= 1,
"The input layer can't produce fewer output tensors than there are inputs."); "The input layer can't produce fewer output tensors than there are inputs.");
add_tag_layer() = default; add_tag_layer():gradient_input_is_stale(true),cached_output_ptr(nullptr) {}
add_tag_layer(const add_tag_layer&) = default; add_tag_layer(const add_tag_layer&) = default;
add_tag_layer& operator=(const add_tag_layer&) = default; add_tag_layer& operator=(const add_tag_layer&) = default;
add_tag_layer(add_tag_layer&& item) : add_tag_layer() { swap(item); } add_tag_layer(add_tag_layer&& item) : add_tag_layer() { swap(item); }
...@@ -1326,17 +1584,30 @@ namespace dlib ...@@ -1326,17 +1584,30 @@ namespace dlib
template <typename T, typename E> template <typename T, typename E>
add_tag_layer( add_tag_layer(
const add_tag_layer<ID,T,E>& item const add_tag_layer<ID,T,E>& item
) : input_layer(item.subnet()) ) : input_layer(item.subnet()),
cached_output(item.cached_output),
cached_output_ptr(nullptr),
grad_final(item.grad_final),
gradient_input_is_stale(item.gradient_input_is_stale)
{} {}
template <typename ...T> template <typename ...T>
add_tag_layer( add_tag_layer(
T ...args T ...args
) : ) :
input_layer(std::move(args)...) input_layer(std::move(args)...),
cached_output_ptr(nullptr),
gradient_input_is_stale(true)
{ {
} }
add_tag_layer (
std::tuple<>
) :
cached_output_ptr(nullptr),
gradient_input_is_stale(true)
{}
template <typename input_iterator> template <typename input_iterator>
void to_tensor ( void to_tensor (
input_iterator ibegin, input_iterator ibegin,
...@@ -1354,6 +1625,7 @@ namespace dlib ...@@ -1354,6 +1625,7 @@ namespace dlib
) )
{ {
input_layer.to_tensor(ibegin,iend,cached_output); input_layer.to_tensor(ibegin,iend,cached_output);
cached_output_ptr = nullptr;
return get_output(); return get_output();
} }
...@@ -1364,36 +1636,49 @@ namespace dlib ...@@ -1364,36 +1636,49 @@ namespace dlib
const tensor& forward(const tensor& x) const tensor& forward(const tensor& x)
{ {
// If this tag is the first layer in one of the subnetworks inside a repeat
// layer then we don't want it to be creating copies of x.  This is because we
// can just hold a pointer to x, since the way repeat is constructed guarantees
// that x will outlive this pointer.
if (is_same_type<INPUT_LAYER, impl::repeat_input_layer>::value)
cached_output_ptr = const_cast<tensor*>(&x);
else
cached_output = x; cached_output = x;
gradient_input_is_stale = true;
return get_output(); return get_output();
} }
const tensor& get_output() const const tensor& get_output() const
{ {
if (cached_output_ptr)
return *cached_output_ptr;
else
return cached_output; return cached_output;
} }
const tensor& get_final_data_gradient( const tensor& get_final_data_gradient(
) const { return grad_final_ignored; } ) const { return grad_final; }
tensor& get_gradient_input() tensor& get_gradient_input()
{ {
if (!have_same_dimensions(cached_output, grad_final_ignored)) if (!have_same_dimensions(get_output(), grad_final) ||
gradient_input_is_stale)
{ {
grad_final_ignored.copy_size(get_output()); grad_final.copy_size(get_output());
grad_final_ignored = 0; grad_final = 0;
gradient_input_is_stale = false;
} }
return grad_final_ignored; return grad_final;
} }
template <typename solver_type> template <typename solver_type>
void update(const tensor& /*x*/, sstack<solver_type,num_layers>& /*solvers*/) void update(const tensor& /*x*/, sstack<solver_type> /*solvers*/)
{ {
// nothing to update // nothing to update
} }
template <typename solver_type> template <typename solver_type>
void update(const tensor& /*x*/, const tensor& gradient_input, sstack<solver_type,num_layers>& /*solvers*/) void update(const tensor& /*x*/, const tensor& gradient_input, sstack<solver_type> /*solvers*/)
{ {
// nothing to update // nothing to update
} }
...@@ -1403,8 +1688,9 @@ namespace dlib ...@@ -1403,8 +1688,9 @@ namespace dlib
void clean() void clean()
{ {
grad_final_ignored.clear(); grad_final.clear();
cached_output.clear(); cached_output.clear();
cached_output_ptr = 0;
} }
friend void serialize(const add_tag_layer& item, std::ostream& out) friend void serialize(const add_tag_layer& item, std::ostream& out)
...@@ -1413,7 +1699,8 @@ namespace dlib ...@@ -1413,7 +1699,8 @@ namespace dlib
serialize(version, out); serialize(version, out);
serialize(item.input_layer, out); serialize(item.input_layer, out);
serialize(item.cached_output, out); serialize(item.cached_output, out);
serialize(item.grad_final_ignored, out); serialize(item.grad_final, out);
serialize(item.gradient_input_is_stale, out);
} }
friend void deserialize(add_tag_layer& item, std::istream& in) friend void deserialize(add_tag_layer& item, std::istream& in)
...@@ -1424,7 +1711,9 @@ namespace dlib ...@@ -1424,7 +1711,9 @@ namespace dlib
throw serialization_error("Unexpected version found while deserializing dlib::add_tag_layer."); throw serialization_error("Unexpected version found while deserializing dlib::add_tag_layer.");
deserialize(item.input_layer, in); deserialize(item.input_layer, in);
deserialize(item.cached_output, in); deserialize(item.cached_output, in);
deserialize(item.grad_final_ignored, in); deserialize(item.grad_final, in);
deserialize(item.gradient_input_is_stale, in);
item.cached_output_ptr = nullptr;
} }
private: private:
...@@ -1437,6 +1726,8 @@ namespace dlib ...@@ -1437,6 +1726,8 @@ namespace dlib
friend class add_tag_layer; friend class add_tag_layer;
template <template<typename> class T, typename U> template <template<typename> class T, typename U>
friend class add_skip_layer; friend class add_skip_layer;
template <size_t N, template<typename> class L, typename S>
friend class repeat;
// You wouldn't put a tag on a layer if you didn't want to access its forward // You wouldn't put a tag on a layer if you didn't want to access its forward
// outputs. So this is always true. // outputs. So this is always true.
...@@ -1455,7 +1746,7 @@ namespace dlib ...@@ -1455,7 +1746,7 @@ namespace dlib
} }
tensor& private_get_output() const tensor& private_get_output() const
{ return get_output(); } { return const_cast<tensor&>(get_output()); }
tensor& private_get_gradient_input() tensor& private_get_gradient_input()
{ return get_gradient_input(); } { return get_gradient_input(); }
...@@ -1463,12 +1754,16 @@ namespace dlib ...@@ -1463,12 +1754,16 @@ namespace dlib
{ {
std::swap(input_layer, item.input_layer); std::swap(input_layer, item.input_layer);
std::swap(cached_output, item.cached_output); std::swap(cached_output, item.cached_output);
std::swap(grad_final_ignored, item.grad_final_ignored); std::swap(cached_output_ptr, item.cached_output_ptr);
std::swap(grad_final, item.grad_final);
std::swap(gradient_input_is_stale, item.gradient_input_is_stale);
} }
subnet_type input_layer; subnet_type input_layer;
resizable_tensor cached_output; resizable_tensor cached_output;
resizable_tensor grad_final_ignored; tensor* cached_output_ptr;
resizable_tensor grad_final;
bool gradient_input_is_stale;
}; };
template <unsigned long ID, typename U, typename E> template <unsigned long ID, typename U, typename E>
...@@ -1653,7 +1948,7 @@ namespace dlib ...@@ -1653,7 +1948,7 @@ namespace dlib
double update ( double update (
const tensor& x, const tensor& x,
label_iterator lbegin, label_iterator lbegin,
sstack<solver_type,num_layers>& solvers sstack<solver_type> solvers
) )
{ {
subnetwork.forward(x); subnetwork.forward(x);
...@@ -1668,7 +1963,7 @@ namespace dlib ...@@ -1668,7 +1963,7 @@ namespace dlib
input_iterator ibegin, input_iterator ibegin,
input_iterator iend, input_iterator iend,
label_iterator lbegin, label_iterator lbegin,
sstack<solver_type,num_layers>& solvers sstack<solver_type> solvers
) )
{ {
to_tensor(ibegin,iend,temp_tensor); to_tensor(ibegin,iend,temp_tensor);
...@@ -1678,7 +1973,7 @@ namespace dlib ...@@ -1678,7 +1973,7 @@ namespace dlib
template <typename solver_type> template <typename solver_type>
double update ( double update (
const tensor& x, const tensor& x,
sstack<solver_type,num_layers>& solvers sstack<solver_type> solvers
) )
{ {
subnetwork.forward(x); subnetwork.forward(x);
...@@ -1692,7 +1987,7 @@ namespace dlib ...@@ -1692,7 +1987,7 @@ namespace dlib
double update ( double update (
input_iterator ibegin, input_iterator ibegin,
input_iterator iend, input_iterator iend,
sstack<solver_type,num_layers>& solvers std::vector<solver_type>& solvers
) )
{ {
to_tensor(ibegin,iend,temp_tensor); to_tensor(ibegin,iend,temp_tensor);
...@@ -1850,7 +2145,7 @@ namespace dlib ...@@ -1850,7 +2145,7 @@ namespace dlib
public: public:
typedef SUBNET subnet_type; typedef SUBNET subnet_type;
typedef typename subnet_type::input_type input_type; typedef typename subnet_type::input_type input_type;
const static size_t num_layers = subnet_type::num_layers + 1; const static size_t num_layers = subnet_type::num_layers;
const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor; const static unsigned int sample_expansion_factor = subnet_type::sample_expansion_factor;
static_assert(sample_expansion_factor >= 1, static_assert(sample_expansion_factor >= 1,
"The input layer can't produce fewer output tensors than there are inputs."); "The input layer can't produce fewer output tensors than there are inputs.");
...@@ -1924,15 +2219,15 @@ namespace dlib ...@@ -1924,15 +2219,15 @@ namespace dlib
} }
template <typename solver_type> template <typename solver_type>
void update(const tensor& x, sstack<solver_type,num_layers>& solvers) void update(const tensor& x, sstack<solver_type> solvers)
{ {
subnetwork.update(x,solvers.pop()); subnetwork.update(x,solvers);
} }
template <typename solver_type> template <typename solver_type>
void update(const tensor& x, const tensor& gradient_input, sstack<solver_type,num_layers>& solvers) void update(const tensor& x, const tensor& gradient_input, sstack<solver_type> solvers)
{ {
subnetwork.update(x,gradient_input,solvers.pop()); subnetwork.update(x,gradient_input,solvers);
} }
const subnet_type& subnet() const const subnet_type& subnet() const
...@@ -1976,6 +2271,8 @@ namespace dlib ...@@ -1976,6 +2271,8 @@ namespace dlib
friend class add_tag_layer; friend class add_tag_layer;
template <template<typename> class T, typename U> template <template<typename> class T, typename U>
friend class add_skip_layer; friend class add_skip_layer;
template <size_t N, template<typename> class L, typename S>
friend class repeat;
bool this_layer_requires_forward_output( bool this_layer_requires_forward_output(
) { return layer<TAG_TYPE>(subnetwork).this_layer_requires_forward_output(); } ) { return layer<TAG_TYPE>(subnetwork).this_layer_requires_forward_output(); }
......
...@@ -69,48 +69,38 @@ namespace dlib ...@@ -69,48 +69,38 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
typename T, typename T
size_t N
> >
class sstack class sstack
{ {
/*! /*!
REQUIREMENTS ON T
- T is default and copy constructable.
REQUIREMENTS ON N
- N > 0
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
This is a basic stack of T objects. It holds N of the objects and is This is a basic stack of T objects. It contains no data itself but simply
entirely allocated on the stack rather than on the heap. points to a memory range of T objects and allows you to access that block of
T objects as a stack.
!*/ !*/
public: public:
typedef T value_type; typedef T value_type;
const static size_t num_elements = N;
sstack( sstack() = delete;
);
/*!
ensures
- #size() == N
- All elements of this stack are default constructed.
!*/
sstack( sstack (
const T& item T* data,
size_t s
); );
/*! /*!
ensures ensures
- #size() == N - #size() == s
- Initializes all N elements in this stack with the given item. E.g. - #top() == *data
top()==item, pop().top()==item, pop().pop().top()==item, etc. - #pop(i).top() == data[i]
!*/ !*/
const T& top( const T& top(
) const; ) const;
/*! /*!
requires
- size() != 0
ensures ensures
- returns the top element of the stack. - returns the top element of the stack.
!*/ !*/
...@@ -118,46 +108,41 @@ namespace dlib ...@@ -118,46 +108,41 @@ namespace dlib
T& top( T& top(
); );
/*! /*!
requires
- size() != 0
ensures ensures
- returns the top element of the stack. - returns the top element of the stack.
!*/ !*/
size_t size( size_t size(
) const; ) const;
/*! /*!
ensures ensures
- returns the number of elements in this stack. In particular, the number - returns the number of elements in this stack.
returned is always N.
!*/
const sstack<T,N-1>& pop(
) const;
/*!
requires
- size() > 1
ensures
- returns a reference to the sub-stack S such that:
- S.size() == size()-1.
- S.top() is the next element in the stack.
!*/ !*/
sstack<T,N-1>& pop( sstack pop(
size_t num = 1
); );
/*! /*!
requires requires
- size() > 1 - num < size()
ensures ensures
- returns a reference to the sub-stack S such that: - returns a reference to the sub-stack S such that:
- S.size() == size()-1. - S.size() == size()-num.
- S.top() is the next element in the stack. - S.top() is num elements down the stack.
!*/ !*/
}; };
void serialize(const sstack& item, std::ostream& out); template <
void deserialize(sstack& item, std::istream& in); typename T
>
sstack<T> make_sstack(
std::vector<T>& item
) { return sstack<T>(item.data(), item.size()); }
/*! /*!
provides serialization support ensures
- returns a sstack that sits on top of the given std::vector.
!*/ !*/
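For illustration, a minimal usage sketch of the new non-owning sstack (the include path and the toy element type are assumptions; only the sstack and make_sstack interfaces documented above are relied on):

    #include <vector>
    #include <dlib/dnn.h>   // assumed to pull in sstack/make_sstack

    void sstack_usage_sketch()
    {
        // The stack owns nothing; it is just a view over this vector.
        std::vector<int> items = {10, 20, 30, 40};
        dlib::sstack<int> s = dlib::make_sstack(items);
        // s.top() == 10 and s.size() == 4

        // pop(num) returns a new view starting num elements further down
        // the stack; the original view s is unchanged.
        auto rest = s.pop(2);
        // rest.top() == 30 and rest.size() == 2

        // Writing through top() writes into the underlying vector.
        s.top() = 99;   // items[0] is now 99
    }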
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -180,6 +165,7 @@ namespace dlib ...@@ -180,6 +165,7 @@ namespace dlib
- SUBNET is an add_layer object. - SUBNET is an add_layer object.
- SUBNET is an add_tag_layer object. - SUBNET is an add_tag_layer object.
- SUBNET is an add_skip_layer object. - SUBNET is an add_skip_layer object.
- SUBNET is a repeat object.
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
This object represents a deep neural network. In particular, it is a tool This object represents a deep neural network. In particular, it is a tool
...@@ -406,7 +392,7 @@ namespace dlib ...@@ -406,7 +392,7 @@ namespace dlib
template <typename solver_type> template <typename solver_type>
void update( void update(
const tensor& x, const tensor& x,
sstack<solver_type,num_layers>& solvers sstack<solver_type> solvers
); );
/*! /*!
requires requires
...@@ -415,9 +401,10 @@ namespace dlib ...@@ -415,9 +401,10 @@ namespace dlib
subsequently modified in any way. subsequently modified in any way.
- get_gradient_input() has been set equal to the gradient of this network's - get_gradient_input() has been set equal to the gradient of this network's
output with respect to some loss function. output with respect to some loss function.
- This instance of solvers has only ever been used with this network. That - The given solvers have only ever been used with this network. That
is, if you want to call update() on some other neural network object then is, if you want to call update() on some other neural network object then
you must not reuse the same solvers object. you must NOT reuse the same solvers object.
- solvers.size() >= num_layers
ensures ensures
- Back propagates the error gradient, get_gradient_input(), through this - Back propagates the error gradient, get_gradient_input(), through this
network and uses the provided solvers to update the network parameters. network and uses the provided solvers to update the network parameters.
...@@ -431,7 +418,7 @@ namespace dlib ...@@ -431,7 +418,7 @@ namespace dlib
void update( void update(
const tensor& x, const tensor& x,
const tensor& gradient_input, const tensor& gradient_input,
sstack<solver_type,num_layers>& solvers sstack<solver_type> solvers
); );
/*! /*!
requires requires
...@@ -439,9 +426,10 @@ namespace dlib ...@@ -439,9 +426,10 @@ namespace dlib
Moreover, this was the most recent call to forward() and x has not been Moreover, this was the most recent call to forward() and x has not been
subsequently modified in any way. subsequently modified in any way.
- have_same_dimensions(gradient_input, get_output()) == true - have_same_dimensions(gradient_input, get_output()) == true
- This instance of solvers has only ever been used with this network. That - The given solvers have only ever been used with this network. That
is, if you want to call update() on some other neural network object then is, if you want to call update() on some other neural network object then
you must not reuse the same solvers object. you must NOT reuse the same solvers object.
- solvers.size() >= num_layers
ensures ensures
- This function is identical to the version of update() defined immediately - This function is identical to the version of update() defined immediately
above except that it back-propagates gradient_input through the network above except that it back-propagates gradient_input through the network
...@@ -504,6 +492,7 @@ namespace dlib ...@@ -504,6 +492,7 @@ namespace dlib
- SUBNET is an add_layer object. - SUBNET is an add_layer object.
- SUBNET is an add_tag_layer object. - SUBNET is an add_tag_layer object.
- SUBNET is an add_skip_layer object. - SUBNET is an add_skip_layer object.
- SUBNET is a repeat object.
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
This object represents a deep neural network. In particular, it is a tool This object represents a deep neural network. In particular, it is a tool
...@@ -766,7 +755,7 @@ namespace dlib ...@@ -766,7 +755,7 @@ namespace dlib
double update ( double update (
const tensor& x, const tensor& x,
label_iterator lbegin, label_iterator lbegin,
sstack<solver_type,num_layers>& solvers sstack<solver_type> solvers
); );
/*! /*!
requires requires
...@@ -774,9 +763,10 @@ namespace dlib ...@@ -774,9 +763,10 @@ namespace dlib
- x.num_samples() > 0 - x.num_samples() > 0
- lbegin == iterator pointing to the start of a range of - lbegin == iterator pointing to the start of a range of
x.num_samples()/sample_expansion_factor label_type elements. x.num_samples()/sample_expansion_factor label_type elements.
- This instance of solvers has only ever been used with this network. That - The given solvers have only ever been used with this network. That
is, if you want to call update() on some other neural network object then is, if you want to call update() on some other neural network object then
you must not reuse the same solvers object. you must NOT reuse the same solvers object.
- solvers.size() >= num_layers
ensures ensures
- runs x through the network, compares the output to the expected output - runs x through the network, compares the output to the expected output
pointed to by lbegin, and updates the network parameters via pointed to by lbegin, and updates the network parameters via
...@@ -793,7 +783,7 @@ namespace dlib ...@@ -793,7 +783,7 @@ namespace dlib
input_iterator ibegin, input_iterator ibegin,
input_iterator iend, input_iterator iend,
label_iterator lbegin, label_iterator lbegin,
sstack<solver_type,num_layers>& solvers sstack<solver_type> solvers
); );
/*! /*!
requires requires
...@@ -801,9 +791,10 @@ namespace dlib ...@@ -801,9 +791,10 @@ namespace dlib
- std::distance(ibegin,iend) > 0 - std::distance(ibegin,iend) > 0
- lbegin == iterator pointing to the start of a range of - lbegin == iterator pointing to the start of a range of
std::distance(ibegin,iend) label_type elements. std::distance(ibegin,iend) label_type elements.
- This instance of solvers has only ever been used with this network. That - The given solvers have only ever been used with this network. That
is, if you want to call update() on some other neural network object then is, if you want to call update() on some other neural network object then
you must not reuse the same solvers object. you must NOT reuse the same solvers object.
- solvers.size() >= num_layers
ensures ensures
- runs [ibegin,iend) through the network, compares the output to the - runs [ibegin,iend) through the network, compares the output to the
expected output pointed to by lbegin, and updates the network parameters expected output pointed to by lbegin, and updates the network parameters
...@@ -820,16 +811,17 @@ namespace dlib ...@@ -820,16 +811,17 @@ namespace dlib
template <typename solver_type> template <typename solver_type>
double update ( double update (
const tensor& x, const tensor& x,
sstack<solver_type,num_layers>& solvers sstack<solver_type> solvers
); );
/*! /*!
requires requires
- LOSS_DETAILS is an unsupervised loss. i.e. label_type==no_label_type. - LOSS_DETAILS is an unsupervised loss. i.e. label_type==no_label_type.
- x.num_samples()%sample_expansion_factor == 0 - x.num_samples()%sample_expansion_factor == 0
- x.num_samples() > 0 - x.num_samples() > 0
- This instance of solvers has only ever been used with this network. That - The given solvers have only ever been used with this network. That
is, if you want to call update() on some other neural network object then is, if you want to call update() on some other neural network object then
you must not reuse the same solvers object. you must NOT reuse the same solvers object.
- solvers.size() >= num_layers
ensures ensures
- runs x through the network and updates the network parameters by - runs x through the network and updates the network parameters by
back-propagating the loss gradient through the network. back-propagating the loss gradient through the network.
...@@ -842,16 +834,17 @@ namespace dlib ...@@ -842,16 +834,17 @@ namespace dlib
double update ( double update (
input_iterator ibegin, input_iterator ibegin,
input_iterator iend, input_iterator iend,
sstack<solver_type,num_layers>& solvers sstack<solver_type> solvers
); );
/*! /*!
requires requires
- LOSS_DETAILS is an unsupervised loss. i.e. label_type==no_label_type. - LOSS_DETAILS is an unsupervised loss. i.e. label_type==no_label_type.
- [ibegin, iend) is an iterator range over input_type objects. - [ibegin, iend) is an iterator range over input_type objects.
- std::distance(ibegin,iend) > 0 - std::distance(ibegin,iend) > 0
- This instance of solvers has only ever been used with this network. That - The given solvers have only ever been used with this network. That
is, if you want to call update() on some other neural network object then is, if you want to call update() on some other neural network object then
you must not reuse the same solvers object. you must NOT reuse the same solvers object.
- solvers.size() >= num_layers
ensures ensures
- runs [ibegin,iend) through the network and updates the network parameters - runs [ibegin,iend) through the network and updates the network parameters
by back-propagating the loss gradient through the network. by back-propagating the loss gradient through the network.
...@@ -881,6 +874,115 @@ namespace dlib ...@@ -881,6 +874,115 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
size_t num,
template<typename> class LAYER,
typename SUBNET
>
class repeat
{
/*!
REQUIREMENTS ON num
- num > 0
REQUIREMENTS ON LAYER
- LAYER must be a template that stacks more layers onto a deep neural
network. For example, if net_type were a network without a loss layer,
then it should be legal to create a deeper network with a type of
LAYER<net_type>.
REQUIREMENTS ON SUBNET
- One of the following must be true:
- SUBNET is an add_layer object.
- SUBNET is an add_tag_layer object.
- SUBNET is an add_skip_layer object.
- SUBNET is a repeat object.
WHAT THIS OBJECT REPRESENTS
This object adds more layers to a deep neural network. In particular, it
adds LAYER on top of SUBNET num times. So for example, if num were 2 then
repeat<2,LAYER,SUBNET> would create a network equivalent to LAYER<LAYER<SUBNET>>.
Also, this object provides an interface identical to the one defined by the
add_layer object except that we add the num_repetitions() and
get_repeated_layer() methods. These additions are shown below along with
some additional explanatory comments.
!*/
public:
typedef SUBNET subnet_type;
typedef typename SUBNET::input_type input_type;
const static size_t num_layers = (LAYER<SUBNET>::num_layers-SUBNET::num_layers)*num + SUBNET::num_layers;
const static unsigned int sample_expansion_factor = SUBNET::sample_expansion_factor;
typedef LAYER<an_unspecified_input_type> repeated_layer_type;
template <typename T, typename ...U>
repeat(
T arg1,
U ...args2
);
/*!
ensures
- arg1 is used to initialize the num_repetitions() copies of LAYER inside
this object. That is, all the LAYER elements are initialized identically
by being given copies of arg1.
- The rest of the arguments to the constructor, i.e. args2, are passed to
SUBNET's constructor.
!*/
size_t num_repetitions (
) const;
/*!
ensures
- returns num (i.e. the number of times LAYER was stacked on top of SUBNET)
!*/
const repeated_layer_type& get_repeated_layer (
size_t i
) const;
/*!
requires
- i < num_repetitions()
ensures
- returns a reference to the i-th instance of LAYER. For example,
get_repeated_layer(0) returns the instance of LAYER that is on the top of
the network while get_repeated_layer(num_repetitions()-1) returns the
instance of LAYER that is stacked immediately on top of SUBNET.
!*/
repeated_layer_type& get_repeated_layer (
size_t i
);
/*!
requires
- i < num_repetitions()
ensures
- returns a reference to the i-th instance of LAYER. For example,
get_repeated_layer(0) returns the instance of LAYER that is on the top of
the network while get_repeated_layer(num_repetitions()-1) returns the
instance of LAYER that is stacked immediately on top of SUBNET.
!*/
const subnet_type& subnet(
) const;
/*!
ensures
- returns the SUBNET base network that repeat sits on top of. If you want
to access the LAYER components then you must use get_repeated_layer().
!*/
subnet_type& subnet(
);
/*!
ensures
- returns the SUBNET base network that repeat sits on top of. If you want
to access the LAYER components then you must use get_repeated_layer().
!*/
};
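As a concrete sketch of the interface above (the relu, input, and matrix names are assumptions about the layer templates available in this version of dlib; the repetition count is arbitrary):

    #include <dlib/dnn.h>
    using namespace dlib;

    // The block to be repeated.  Any template that stacks layers onto its
    // SUBNET argument will do; a lone relu keeps the example small.
    template <typename SUBNET> using my_block = relu<SUBNET>;

    // The base network the repetitions sit on.  It must itself be an
    // add_layer/add_tag_layer/add_skip_layer/repeat object, so one layer
    // is placed on top of the raw input layer.
    using bottom_type = relu<input<matrix<float>>>;

    // Equivalent to my_block<my_block<...<bottom_type>...>> nested 16 times,
    // but without 16 levels of template recursion in user code.  Each
    // repetition adds LAYER<SUBNET>::num_layers - SUBNET::num_layers == 1
    // layer here, so net_type::num_layers == 16 + bottom_type::num_layers.
    using net_type = repeat<16, my_block, bottom_type>;

    void repeat_usage_sketch()
    {
        net_type net;
        // net.num_repetitions() == 16
        // net.get_repeated_layer(0) is the top-most my_block instance and
        // net.get_repeated_layer(15) sits directly on top of net.subnet().
    }

Note that making the count larger only changes the first template argument; the type of the internally repeated networks is LAYER<an_unspecified_input_type> regardless of num, so the depth of template instantiation seen by the compiler stays essentially flat.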
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -897,6 +999,7 @@ namespace dlib ...@@ -897,6 +999,7 @@ namespace dlib
- SUBNET is an add_layer object. - SUBNET is an add_layer object.
- SUBNET is an add_tag_layer object. - SUBNET is an add_tag_layer object.
- SUBNET is an add_skip_layer object. - SUBNET is an add_skip_layer object.
- SUBNET is a repeat object.
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
This object adds a new layer to a deep neural network. However, this layer This object adds a new layer to a deep neural network. However, this layer
...@@ -942,6 +1045,7 @@ namespace dlib ...@@ -942,6 +1045,7 @@ namespace dlib
- SUBNET is an add_layer object. - SUBNET is an add_layer object.
- SUBNET is an add_tag_layer object. - SUBNET is an add_tag_layer object.
- SUBNET is an add_skip_layer object. - SUBNET is an add_skip_layer object.
- SUBNET is a repeat object.
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
This object adds a new layer to a deep neural network which draws its This object adds a new layer to a deep neural network which draws its
......
...@@ -48,7 +48,7 @@ namespace dlib ...@@ -48,7 +48,7 @@ namespace dlib
dnn_trainer( dnn_trainer(
const net_type& net_, const net_type& net_,
const solver_type& solver_ const solver_type& solver_
) : job_pipe(0), net(net_), solvers(solver_) ) : job_pipe(0), net(net_), solvers(net_type::num_layers, solver_)
{ {
init(); init();
} }
...@@ -81,7 +81,7 @@ namespace dlib ...@@ -81,7 +81,7 @@ namespace dlib
) )
{ {
wait_for_thread_to_pause(); wait_for_thread_to_pause();
solvers = solver_; solvers = std::vector<solver_type>(net_type::num_layers, solver_);
} }
unsigned long get_mini_batch_size ( unsigned long get_mini_batch_size (
...@@ -119,14 +119,14 @@ namespace dlib ...@@ -119,14 +119,14 @@ namespace dlib
} }
const sstack<solver_type,net_type::num_layers>& get_solvers ( const std::vector<solver_type>& get_solvers (
) const ) const
{ {
wait_for_thread_to_pause(); wait_for_thread_to_pause();
return solvers; return solvers;
} }
sstack<solver_type,net_type::num_layers>& get_solvers ( std::vector<solver_type>& get_solvers (
) )
{ {
wait_for_thread_to_pause(); wait_for_thread_to_pause();
...@@ -260,7 +260,7 @@ namespace dlib ...@@ -260,7 +260,7 @@ namespace dlib
friend void serialize(const dnn_trainer& item, std::ostream& out) friend void serialize(const dnn_trainer& item, std::ostream& out)
{ {
item.wait_for_thread_to_pause(); item.wait_for_thread_to_pause();
int version = 1; int version = 2;
serialize(version, out); serialize(version, out);
serialize(item.rs, out); serialize(item.rs, out);
serialize(item.num_epochs, out); serialize(item.num_epochs, out);
...@@ -275,7 +275,7 @@ namespace dlib ...@@ -275,7 +275,7 @@ namespace dlib
item.wait_for_thread_to_pause(); item.wait_for_thread_to_pause();
int version = 0; int version = 0;
deserialize(version, in); deserialize(version, in);
if (version != 1) if (version != 2)
throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer."); throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
deserialize(item.rs, in); deserialize(item.rs, in);
deserialize(item.num_epochs, in); deserialize(item.num_epochs, in);
...@@ -309,13 +309,13 @@ namespace dlib ...@@ -309,13 +309,13 @@ namespace dlib
template <typename T> template <typename T>
void run_update(job_t& next_job, const T&) void run_update(job_t& next_job, const T&)
{ {
rs.add(net.update(next_job.t, next_job.labels.begin(), solvers)); rs.add(net.update(next_job.t, next_job.labels.begin(), make_sstack(solvers)));
} }
void run_update(job_t& next_job, const no_label_type&) void run_update(job_t& next_job, const no_label_type&)
{ {
no_label_type pick_wich_run_update; no_label_type pick_wich_run_update;
rs.add(net.update(next_job.t, solvers)); rs.add(net.update(next_job.t, make_sstack(solvers)));
} }
void thread() void thread()
...@@ -361,7 +361,7 @@ namespace dlib ...@@ -361,7 +361,7 @@ namespace dlib
int cuda_device_id; int cuda_device_id;
net_type net; net_type net;
sstack<solver_type,net_type::num_layers> solvers; std::vector<solver_type> solvers;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -93,24 +93,30 @@ namespace dlib ...@@ -93,24 +93,30 @@ namespace dlib
assigned to each element in get_solvers(). assigned to each element in get_solvers().
!*/ !*/
const sstack<solver_type,net_type::num_layers>& get_solvers ( const std::vector<solver_type>& get_solvers (
) const; ) const;
/*! /*!
ensures ensures
- returns the solvers used to optimize each layer of the neural network - returns the solvers used to optimize each layer of the neural network
get_net(). In particular, the first layer's solver is get_net(). In particular, the first layer's solver is
get_solvers().top(), the second layer's solver is get_solvers()[0], the second layer's solver is
get_solvers().pop().top(), and so on. get_solvers()[1], and so on.
!*/ !*/
sstack<solver_type,net_type::num_layers>& get_solvers ( std::vector<solver_type>& get_solvers (
); );
/*! /*!
ensures ensures
- returns the solvers used to optimize each layer of the neural network - returns the solvers used to optimize each layer of the neural network
get_net(). In particular, the first layer's solver is get_net(). In particular, the first layer's solver is
get_solvers().top(), the second layer's solver is get_solvers()[0], the second layer's solver is
get_solvers().pop().top(), and so on. get_solvers()[1], and so on.
- It should be noted that you should never change the number of elements in
the vector returned by get_solvers() (i.e. don't do something that
changes get_solvers().size()). It will be set to net_type::num_layers by
this object and you should leave it at that. The non-const version of
get_solvers() is provided only so you can tweak the parameters of a
particular solver.
!*/ !*/
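A short sketch of working with the new vector-based solver access (net_type stands for some dlib network type defined elsewhere; sgd is assumed to be the stochastic gradient descent solver shipped with this version of dlib):

    #include <vector>
    #include <dlib/dnn.h>

    template <typename net_type>
    void inspect_solvers(net_type& net)
    {
        using namespace dlib;
        dnn_trainer<net_type, sgd> trainer(net, sgd());

        // One solver per layer, in layer order: element 0 drives the first
        // (top-most) layer, element 1 the next, and so on.  Individual
        // solvers can be tweaked through this vector, but its size must
        // never be changed.
        std::vector<sgd>& solvers = trainer.get_solvers();
        DLIB_CASSERT(solvers.size() == net_type::num_layers, "");
    }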
unsigned long get_mini_batch_size ( unsigned long get_mini_batch_size (
......
...@@ -974,8 +974,8 @@ namespace ...@@ -974,8 +974,8 @@ namespace
rcon_(6) rcon_(6)
); );
DLIB_TEST(layer<tag1>(net).num_layers == 9); DLIB_TEST(layer<tag1>(net).num_layers == 8);
DLIB_TEST(layer<skip1>(net).num_layers == 9+3+3+1); DLIB_TEST(layer<skip1>(net).num_layers == 8+3+3);
DLIB_TEST(&layer<skip1>(net).get_output() == &layer<tag1>(net).get_output()); DLIB_TEST(&layer<skip1>(net).get_output() == &layer<tag1>(net).get_output());
DLIB_TEST(&layer<skip1>(net).get_output() != &layer<tag1>(net).subnet().subnet().get_output()); DLIB_TEST(&layer<skip1>(net).get_output() != &layer<tag1>(net).subnet().subnet().get_output());
} }
......