Commit f335ce4f authored by Davis King

Adding a rough initial version of a deep learning API.

parent 16ea6f11
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNn_
#define DLIB_DNn_
#include "dnn/tensor.h"
#include "dnn/input.h"
#include "dnn/layers.h"
#include "dnn/loss.h"
#include "dnn/core.h"
#include "dnn/solvers.h"
#endif // DLIB_DNn_
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNn_CORE_H_
#define DLIB_DNn_CORE_H_
#include "core_abstract.h"
#include "tensor.h"
#include "solvers.h"
#include <iterator>
#include <algorithm>
#include <cmath>
#include <memory>
#include <type_traits>
#include <dlib/statistics.h>
#include <dlib/rand.h>
#include <utility>
namespace dlib
{
// ----------------------------------------------------------------------------------------
// Tell us if T is one of the special layer types (i.e. add_layer, add_tag, or
// add_skip) and, separately, if T is a loss layer type (i.e. add_loss).
template <typename T> struct is_layer_type : std::false_type {};
template <typename T> struct is_loss_layer_type : std::false_type {};
// ----------------------------------------------------------------------------------------
inline void randomize_parameters (
tensor& params,
unsigned long num_inputs_and_outputs,
dlib::rand& rnd
)
{
float* data = params.host();
for (size_t i = 0; i < params.size(); ++i)
{
// Draw a random number to initialize the layer according to formula (16)
// from the paper "Understanding the difficulty of training deep feedforward
// neural networks" by Xavier Glorot and Yoshua Bengio.
float val = 2*rnd.get_random_float()-1;
val *= std::sqrt(6.0/(num_inputs_and_outputs));
data[i] = val;
}
}
// ----------------------------------------------------------------------------------------
template <typename T, size_t N>
class sstack
{
public:
static_assert(N > 0, "You can't create an empty sstack.");
typedef T value_type;
const static size_t num_elements = N;
sstack() {}
sstack(const T& item_) : item(item_), data(item_) {}
const T& top() const { return item; }
T& top() { return item; }
size_t size() const { return N; }
const sstack<T,N-1>& pop() const { return data; }
sstack<T,N-1>& pop() { return data; }
private:
T item;
sstack<T,N-1> data;
};
template <typename T>
class sstack<T,1> // base case of recursive definition.
{
public:
sstack() {}
explicit sstack(const T& item_) : item(item_) {}
const T& top() const { return item; }
T& top() { return item; }
size_t size() const { return 1; }
private:
T item;
};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
namespace dimpl
{
template <typename T, typename enabled=void>
class sub_net_wrapper
{
/*!
WHAT THIS OBJECT REPRESENTS
This is a tool that makes an add_layer or add_loss object
expose only the part of its interface defined by the SUB_NET
type in layers_abstract.h. This way, when we pass sub network
objects to the layer callbacks, those callbacks can only
interact with the sub networks in the ways specified by the
SUB_NET interface spec.
!*/
public:
sub_net_wrapper(T& l_) {}
// Nothing here because in this case T is one of the input layer types
// that doesn't have anything in it.
};
template <typename T>
class sub_net_wrapper<T,typename std::enable_if<is_layer_type<T>::value>::type>
{
public:
typedef T wrapped_type;
const static size_t num_layers = T::num_layers;
sub_net_wrapper(T& l_) : l(l_),sub(l.sub_net()) {}
const tensor& get_output() const { return l.get_output(); }
tensor& get_gradient_input() { return l.get_gradient_input(); }
const sub_net_wrapper<typename T::sub_net_type>& sub_net() const { return sub; }
sub_net_wrapper<typename T::sub_net_type>& sub_net() { return sub; }
private:
T& l;
sub_net_wrapper<typename T::sub_net_type> sub;
};
}
template <typename LAYER_DETAILS, typename SUB_NET, typename enabled = void>
class add_layer;
template <typename T, typename U>
struct is_layer_type<add_layer<T,U>> : std::true_type {};
template <typename LAYER_DETAILS, typename SUB_NET>
class add_layer<LAYER_DETAILS,SUB_NET,
typename std::enable_if<is_layer_type<SUB_NET>::value>::type>
{
public:
typedef LAYER_DETAILS layer_details_type;
typedef SUB_NET sub_net_type;
typedef typename sub_net_type::input_type input_type;
const static size_t num_layers = sub_net_type::num_layers + 1;
const static unsigned int sample_expansion_factor = sub_net_type::sample_expansion_factor;
add_layer(
):
this_layer_setup_called(false),
gradient_input_is_stale(true)
{
}
add_layer(const add_layer&) = default;
add_layer(add_layer&&) = default;
add_layer& operator=(add_layer&&) = default;
add_layer& operator=(const add_layer&) = default;
template <typename T, typename U, typename E>
friend class add_layer;
// Allow copying networks from one to another as long as their corresponding
// layers can be constructed from each other.
template <typename T, typename U, typename E>
add_layer(
const add_layer<T,U,E>& item
) :
sub_network(item.sub_net()),
details(item.layer_details()),
this_layer_setup_called(item.this_layer_setup_called),
gradient_input_is_stale(item.gradient_input_is_stale),
x_grad(item.x_grad),
cached_output(item.cached_output)
{
}
template <typename ...T>
add_layer(
const LAYER_DETAILS& layer_det,
T&& ...args
) :
sub_network(std::forward<T>(args)...),
details(layer_det),
this_layer_setup_called(false),
gradient_input_is_stale(true)
{
}
template <typename ...T>
add_layer(
LAYER_DETAILS&& layer_det,
T&& ...args
) :
sub_network(std::forward<T>(args)...),
details(std::move(layer_det)),
this_layer_setup_called(false),
gradient_input_is_stale(true)
{
}
template <typename input_iterator>
void to_tensor (
input_iterator begin,
input_iterator end,
resizable_tensor& data
) const
{
sub_network.to_tensor(begin,end,data);
}
template <typename input_iterator>
const tensor& operator() (
input_iterator ibegin,
input_iterator iend
)
/*!
ensures
- runs [ibegin,iend) through the network and returns the results
!*/
{
to_tensor(ibegin,iend,temp_tensor);
return forward(temp_tensor);
}
const tensor& operator() (const input_type& x)
/*!
ensures
- runs a single x through the network and returns the output.
!*/
{
return (*this)(&x, &x+1);
}
const tensor& forward(const tensor& x)
{
sub_network.forward(x);
const dimpl::sub_net_wrapper<sub_net_type> wsub(sub_network);
if (!this_layer_setup_called)
{
details.setup(wsub);
this_layer_setup_called = true;
}
details.forward(wsub, cached_output);
gradient_input_is_stale = true;
return get_output();
}
const tensor& get_output() const { return cached_output; }
tensor& get_gradient_input()
{
if (gradient_input_is_stale)
{
gradient_input_is_stale = false;
x_grad.copy_size(get_output());
x_grad = 0;
}
return x_grad;
}
template <typename solver_type>
void update(const tensor& x, sstack<solver_type,num_layers>& solvers)
/*!
requires
- forward(x) was called to forward propagate x through the network.
- x.num_samples() == get_gradient_input().num_samples()
- get_gradient_input() == the gradient of the network with respect
to some loss.
!*/
{
dimpl::sub_net_wrapper<sub_net_type> wsub(sub_network);
params_grad.copy_size(details.get_layer_params());
params_grad = 0;
details.backward(get_gradient_input(), wsub, static_cast<tensor&>(params_grad));
// Don't try to adjust the parameters if this layer doesn't have any.
if (params_grad.size() != 0)
solvers.top()(details, static_cast<const tensor&>(params_grad));
sub_network.update(x, solvers.pop());
}
const sub_net_type& sub_net() const { return sub_network; }
sub_net_type& sub_net() { return sub_network; }
const layer_details_type& layer_details() const { return details; }
layer_details_type& layer_details() { return details; }
void clean()
{
x_grad.clear();
cached_output.clear();
params_grad.clear();
temp_tensor.clear();
gradient_input_is_stale = true;
sub_network.clean();
}
private:
sub_net_type sub_network;
LAYER_DETAILS details;
bool this_layer_setup_called;
bool gradient_input_is_stale;
resizable_tensor x_grad;
resizable_tensor cached_output;
// The following 2 objects don't logically contribute to the state of this class.
// They are only here to prevent them from being reallocated over and over in
// member functions.
resizable_tensor params_grad;
resizable_tensor temp_tensor;
};
// ----------------------------------------------------------------------------------------
template <typename LAYER_DETAILS, typename INPUT_LAYER, typename enabled>
class add_layer
{
public:
typedef LAYER_DETAILS layer_details_type;
typedef INPUT_LAYER sub_net_type;
typedef typename INPUT_LAYER::input_type input_type;
const static unsigned int sample_expansion_factor = INPUT_LAYER::sample_expansion_factor;
const static size_t num_layers = 1;
static_assert(sample_expansion_factor >= 1,
"The input layer can't produce fewer output tensors than there are inputs.");
add_layer(
):
this_layer_setup_called(false),
gradient_input_is_stale(true)
{}
add_layer(const add_layer&) = default;
add_layer(add_layer&&) = default;
add_layer& operator=(add_layer&&) = default;
add_layer& operator=(const add_layer&) = default;
template <typename T, typename U, typename E>
friend class add_layer;
// Allow copying networks from one to another as long as their corresponding
// layers can be constructed from each other.
template <typename T, typename U, typename E>
add_layer(
const add_layer<T,U,E>& item
):
input_layer(item.sub_net()),
details(item.layer_details()),
this_layer_setup_called(item.this_layer_setup_called),
gradient_input_is_stale(item.gradient_input_is_stale),
x_grad(item.x_grad),
cached_output(item.cached_output)
{
}
add_layer(
const LAYER_DETAILS& layer_det
) :
details(layer_det),
this_layer_setup_called(false),
gradient_input_is_stale(true)
{}
add_layer(
LAYER_DETAILS&& layer_det
) :
details(std::move(layer_det)),
this_layer_setup_called(false),
gradient_input_is_stale(true)
{}
add_layer(
LAYER_DETAILS layer_det,
INPUT_LAYER il
) :
input_layer(il),
details(layer_det),
this_layer_setup_called(false),
gradient_input_is_stale(true)
{}
template <typename input_iterator>
void to_tensor (
input_iterator begin,
input_iterator end,
resizable_tensor& data
) const
{
input_layer.to_tensor(begin, end, data);
// make sure the input layer's to_tensor() function is implemented properly.
DLIB_CASSERT(std::distance(begin,end)*sample_expansion_factor == data.num_samples(),"");
data.async_copy_to_device();
}
template <typename input_iterator>
const tensor& operator() (
input_iterator ibegin,
input_iterator iend
)
/*!
ensures
- runs [ibegin,iend) through the network and returns the results
!*/
{
to_tensor(ibegin,iend,temp_tensor);
return forward(temp_tensor);
}
const tensor& operator() (const input_type& x)
/*!
ensures
- runs a single x through the network and returns the output.
!*/
{
return (*this)(&x, &x+1);
}
const tensor& forward (const tensor& x)
/*!
requires
- x.num_samples() is a multiple of sample_expansion_factor.
!*/
{
DLIB_CASSERT(x.num_samples()%sample_expansion_factor == 0,"");
sub_net_wrapper wsub(x, grad_final_ignored);
if (!this_layer_setup_called)
{
details.setup(wsub);
this_layer_setup_called = true;
}
details.forward(wsub, cached_output);
gradient_input_is_stale = true;
return get_output();
}
const tensor& get_output() const { return cached_output; }
tensor& get_gradient_input()
{
if (gradient_input_is_stale)
{
gradient_input_is_stale = false;
x_grad.copy_size(get_output());
x_grad = 0;
}
return x_grad;
}
template <typename solver_type>
void update(const tensor& x, sstack<solver_type,num_layers>& solvers)
/*!
requires
- x.num_samples() is a multiple of sample_expansion_factor.
- forward(x) was called to forward propagate x through the network.
- x.num_samples() == get_gradient_input().num_samples()
!*/
{
sub_net_wrapper wsub(x, grad_final_ignored);
params_grad.copy_size(details.get_layer_params());
params_grad = 0;
details.backward(get_gradient_input(), wsub, static_cast<tensor&>(params_grad));
// Don't try to adjust the parameters if this layer doesn't have any.
if (params_grad.size() != 0)
solvers.top()(details, static_cast<const tensor&>(params_grad));
}
const sub_net_type& sub_net() const { return input_layer; }
sub_net_type& sub_net() { return input_layer; }
const layer_details_type& layer_details() const { return details; }
layer_details_type& layer_details() { return details; }
void clean()
{
x_grad.clear();
grad_final_ignored.clear();
cached_output.clear();
params_grad.clear();
temp_tensor.clear();
gradient_input_is_stale = true;
}
private:
class sub_net_wrapper
{
public:
sub_net_wrapper(const tensor& x_, resizable_tensor& grad_final_ignored_) :
x(x_), grad_final_ignored(grad_final_ignored_) {}
const tensor& get_output() const { return x; }
tensor& get_gradient_input()
{
// It doesn't matter what values are in this tensor but client code will
// always assume it's the same dimension as the output so make sure that is
// the case.
grad_final_ignored.copy_size(x);
return grad_final_ignored;
}
private:
const tensor& x;
resizable_tensor& grad_final_ignored;
};
sub_net_type input_layer;
LAYER_DETAILS details;
bool this_layer_setup_called;
bool gradient_input_is_stale;
resizable_tensor x_grad;
resizable_tensor cached_output;
// The following 3 objects don't logically contribute to the state of this class.
// They are only here to prevent them from being reallocated over and over in
// member functions.
resizable_tensor params_grad;
resizable_tensor temp_tensor;
resizable_tensor grad_final_ignored;
};
// ----------------------------------------------------------------------------------------
template <unsigned long ID, typename SUB_NET>
class add_tag
{
public:
typedef SUB_NET sub_net_type;
typedef typename sub_net_type::input_type input_type;
const static size_t num_layers = sub_net_type::num_layers + 1;
const static unsigned int sample_expansion_factor = sub_net_type::sample_expansion_factor;
static_assert(sample_expansion_factor >= 1,
"The input layer can't produce fewer output tensors than there are inputs.");
add_tag() = default;
add_tag(const add_tag&) = default;
add_tag(add_tag&&) = default;
add_tag& operator=(add_tag&&) = default;
add_tag& operator=(const add_tag&) = default;
template <typename T>
add_tag(
const add_tag<ID,T>& item
) : sub_network(item.sub_net())
{}
template <typename ...T>
add_tag(
T ...args
) :
sub_network(std::move(args)...)
{
}
template <typename input_iterator>
void to_tensor (
input_iterator begin,
input_iterator end,
resizable_tensor& data
) const
{
sub_network.to_tensor(begin,end,data);
}
template <typename input_iterator>
const tensor& operator() (
input_iterator ibegin,
input_iterator iend
)
{
return sub_network(ibegin,iend);
}
const tensor& operator() (const input_type& x)
{
return sub_network(x);
}
const tensor& forward(const tensor& x)
{
return sub_network.forward(x);
}
const tensor& get_output() const { return sub_network.get_output(); }
tensor& get_gradient_input()
{
return sub_network.get_gradient_input();
}
template <typename solver_type>
void update(const tensor& x, sstack<solver_type,num_layers>& solvers)
{
sub_network.update(x,solvers.pop());
}
const sub_net_type& sub_net() const { return sub_network; }
sub_net_type& sub_net() { return sub_network; }
void clean()
{
sub_network.clean();
}
private:
sub_net_type sub_network;
};
template <unsigned long ID, typename U>
struct is_layer_type<add_tag<ID,U>> : std::true_type {};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <typename LOSS_DETAILS, typename SUB_NET>
class add_loss;
class no_label_type
{
private:
// We don't want anyone making these no_label_type objects. They are here only to
// allow add_loss::label_type and dnn_trainer::label_type to exist, which avoids
// needing to overload add_loss and dnn_trainer for supervised and unsupervised
// losses. It also can be a type to use in template metaprogramming to indicate
// "no label". So here we make the constructor private with the exception that
// add_loss objects can make it (again, just to simplify add_loss's
// implementation).
no_label_type()=default;
template <typename LOSS_DETAILS, typename SUB_NET> friend class add_loss;
};
// ----------------------------------------------------------------------------------------
template <typename LOSS_DETAILS, typename SUB_NET>
class add_loss
{
template <typename T, typename enabled=void>
struct get_loss_layer_label_type
{
typedef no_label_type type;
};
template <typename T>
struct get_loss_layer_label_type<T,typename std::enable_if<sizeof(typename T::label_type)!=0>::type>
{
typedef typename T::label_type type;
};
public:
typedef LOSS_DETAILS loss_details_type;
typedef SUB_NET sub_net_type;
typedef typename sub_net_type::input_type input_type;
// Note that the loss layer doesn't count as an additional layer.
const static size_t num_layers = sub_net_type::num_layers;
const static unsigned int sample_expansion_factor = sub_net_type::sample_expansion_factor;
typedef typename get_loss_layer_label_type<LOSS_DETAILS>::type label_type;
static_assert(is_layer_type<SUB_NET>::value, "SUB_NET must be of type add_layer, add_skip, or add_tag.");
static_assert(sample_expansion_factor == LOSS_DETAILS::sample_expansion_factor,
"The loss layer and input layer must agree on the sample_expansion_factor.");
add_loss() = default;
add_loss(const add_loss&) = default;
add_loss(add_loss&&) = default;
add_loss& operator=(add_loss&&) = default;
add_loss& operator=(const add_loss&) = default;
template <typename T, typename U>
add_loss(
const add_loss<T,U>& item
) :
loss(item.loss_details()),
sub(item.sub_net())
{}
template <typename ...T>
add_loss(
const LOSS_DETAILS& layer_det,
T&& ...args
) :
loss(layer_det),
sub(std::forward<T>(args)...)
{
}
template <typename ...T>
add_loss(
LOSS_DETAILS&& layer_det,
T&& ...args
) :
loss(std::move(layer_det)),
sub(std::forward<T>(args)...)
{
}
template <typename ...T>
add_loss(
T ...args
) :
sub(std::move(args)...)
{
}
template <typename input_iterator, typename output_iterator>
void operator() (
input_iterator ibegin,
input_iterator iend,
output_iterator obegin
)
/*!
requires
- obegin == iterator pointing to the start of a range of distance(ibegin,iend)
elements.
ensures
- runs [ibegin,iend) through the network and writes the output to the range at obegin.
!*/
{
sub.to_tensor(ibegin,iend,temp_tensor);
sub.forward(temp_tensor);
loss.to_label(sub, obegin);
}
const label_type& operator() (const input_type& x)
/*!
ensures
- runs a single x through the network and returns the output.
!*/
{
(*this)(&x, &x+1, &temp_label);
return temp_label;
}
template <typename input_iterator, typename label_iterator>
double compute_loss (
input_iterator ibegin,
input_iterator iend,
label_iterator lbegin
)
{
sub.to_tensor(ibegin,iend,temp_tensor);
sub.forward(temp_tensor);
dimpl::sub_net_wrapper<sub_net_type> wsub(sub);
return loss.compute_loss(temp_tensor, lbegin, wsub);
}
template <typename input_iterator>
double compute_loss (
input_iterator ibegin,
input_iterator iend
)
{
sub.to_tensor(ibegin,iend,temp_tensor);
sub.forward(temp_tensor);
dimpl::sub_net_wrapper<sub_net_type> wsub(sub);
return loss.compute_loss(temp_tensor, wsub);
}
template <typename input_iterator, typename label_iterator, typename solver_type>
double update (
input_iterator ibegin,
input_iterator iend,
label_iterator lbegin,
sstack<solver_type,num_layers>& solvers
)
{
sub.to_tensor(ibegin,iend,temp_tensor);
sub.forward(temp_tensor);
dimpl::sub_net_wrapper<sub_net_type> wsub(sub);
double l = loss.compute_loss(temp_tensor, lbegin, wsub);
sub.update(temp_tensor, solvers);
return l;
}
template <typename input_iterator, typename solver_type>
double update (
input_iterator ibegin,
input_iterator iend,
sstack<solver_type,num_layers>& solvers
)
{
sub.to_tensor(ibegin,iend,temp_tensor);
sub.forward(temp_tensor);
dimpl::sub_net_wrapper<sub_net_type> wsub(sub);
double l = loss.compute_loss(temp_tensor, wsub);
sub.update(temp_tensor, solvers);
return l;
}
const sub_net_type& sub_net() const { return sub; }
sub_net_type& sub_net() { return sub; }
const loss_details_type& loss_details() const { return loss; }
loss_details_type& loss_details() { return loss; }
void clean (
)
/*!
ensures
- Causes the network to forget about everything but its parameters.
That is, for each layer we will have:
- get_output().num_samples() == 0
- get_gradient_input().num_samples() == 0
However, running new input data through this network will still produce
the same output it would have produced regardless of any calls to clean().
Finally, the purpose of clean() is to compact the network object prior to
saving it to disk so that it takes up less space and the IO is quicker.
!*/
{
temp_tensor.clear();
sub.clean();
}
private:
loss_details_type loss;
sub_net_type sub;
// These two objects don't logically contribute to the state of this object. They
// are here to prevent them from being reallocated over and over.
label_type temp_label;
resizable_tensor temp_tensor;
};
template <typename T, typename U>
struct is_loss_layer_type<add_loss<T,U>> : std::true_type {};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
namespace impl
{
template <unsigned int i, typename T>
struct layer_helper
{
static T& makeT();
using next_type = typename std::remove_reference<decltype(makeT().sub_net())>::type;
using type = typename layer_helper<i-1,next_type>::type;
static type& layer(T& n)
{
return layer_helper<i-1,next_type>::layer(n.sub_net());
}
};
template <typename T>
struct layer_helper<0,T>
{
using type = T;
static type& layer(T& n)
{
return n;
}
};
template <template<typename> class Match, typename T, unsigned int i, typename enabled = void>
struct layer_helper_match
{
static T& makeT();
using next_type = typename std::remove_reference<decltype(makeT().sub_net())>::type;
using type = typename layer_helper_match<Match,next_type,i>::type;
static type& layer(T& n)
{
return layer_helper_match<Match,next_type,i>::layer(n.sub_net());
}
};
// This overload catches add_layer and add_loss templates.
template <template<typename> class Match, typename T, unsigned int i>
struct layer_helper_match<Match,T,i,
typename std::enable_if<std::is_same<const T,const Match<typename T::sub_net_type>>::value>::type>
{
using type = typename layer_helper<i,T>::type;
static type& layer(T& n)
{
return layer_helper<i,T>::layer(n);
}
};
// This overload catches input templates.
template <template<typename> class Match, typename T, unsigned int i>
struct layer_helper_match<Match,T,i,
typename std::enable_if<std::is_same<const T,const Match<typename T::input_type>>::value>::type>
{
using type = typename layer_helper<i,T>::type;
static type& layer(T& n)
{
return layer_helper<i,T>::layer(n);
}
};
// This overload catches sub_net_wrapper templates.
template <template<typename> class Match, typename T, unsigned int i>
struct layer_helper_match<Match,T,i,
typename std::enable_if<std::is_same<const typename T::wrapped_type,
const Match<typename T::wrapped_type::sub_net_type>>::value>::type>
{
using type = typename layer_helper<i,T>::type;
static type& layer(T& n)
{
return layer_helper<i,T>::layer(n);
}
};
}
template <unsigned int i, typename T>
typename impl::layer_helper<i,T>::type& layer (T& n)
{
return impl::layer_helper<i,T>::layer(n);
}
template <template<typename> class Match, typename T>
typename impl::layer_helper_match<Match,T,0>::type& layer (T& n)
{
return impl::layer_helper_match<Match,T,0>::layer(n);
}
template <template<typename> class Match, unsigned int i, typename T>
typename impl::layer_helper_match<Match,T,i>::type& layer (T& n)
{
return impl::layer_helper_match<Match,T,i>::layer(n);
}
// ----------------------------------------------------------------------------------------
template <template<typename> class TAG_TYPE, typename SUB_NET>
class add_skip
{
/*!
WHAT THIS OBJECT REPRESENTS
This object draws its inputs from layer<TAG_TYPE>(SUB_NET())
and performs the identity transform.
!*/
public:
typedef SUB_NET sub_net_type;
typedef typename sub_net_type::input_type input_type;
const static size_t num_layers = sub_net_type::num_layers + 1;
const static unsigned int sample_expansion_factor = sub_net_type::sample_expansion_factor;
static_assert(sample_expansion_factor >= 1,
"The input layer can't produce fewer output tensors than there are inputs.");
add_skip() = default;
add_skip(const add_skip&) = default;
add_skip(add_skip&&) = default;
add_skip& operator=(add_skip&&) = default;
add_skip& operator=(const add_skip&) = default;
template <typename T>
add_skip(
const add_skip<TAG_TYPE,T>& item
) : sub_network(item.sub_net())
{}
template <typename ...T>
add_skip(
T ...args
) :
sub_network(std::move(args)...)
{
}
template <typename input_iterator>
void to_tensor (
input_iterator begin,
input_iterator end,
resizable_tensor& data
) const
{
sub_network.to_tensor(begin,end,data);
}
template <typename input_iterator>
const tensor& operator() (
input_iterator ibegin,
input_iterator iend
)
{
sub_network(ibegin,iend);
return layer<TAG_TYPE>(sub_network).get_output();
}
const tensor& operator() (const input_type& x)
{
sub_network(x);
return layer<TAG_TYPE>(sub_network).get_output();
}
const tensor& forward(const tensor& x)
{
sub_network.forward(x);
return layer<TAG_TYPE>(sub_network).get_output();
}
const tensor& get_output() const
{
return layer<TAG_TYPE>(sub_network).get_output();
}
tensor& get_gradient_input()
{
return layer<TAG_TYPE>(sub_network).get_gradient_input();
}
template <typename solver_type>
void update(const tensor& x, sstack<solver_type,num_layers>& solvers)
{
sub_network.update(x,solvers.pop());
}
const sub_net_type& sub_net() const
{
return sub_network;
}
sub_net_type& sub_net()
{
return sub_network;
}
void clean()
{
sub_network.clean();
}
private:
sub_net_type sub_network;
};
template <template<typename> class T, typename U>
struct is_layer_type<add_skip<T,U>> : std::true_type {};
template <typename SUB_NET> using tag1 = add_tag< 1, SUB_NET>;
template <typename SUB_NET> using tag2 = add_tag< 2, SUB_NET>;
template <typename SUB_NET> using tag3 = add_tag< 3, SUB_NET>;
template <typename SUB_NET> using tag4 = add_tag< 4, SUB_NET>;
template <typename SUB_NET> using tag5 = add_tag< 5, SUB_NET>;
template <typename SUB_NET> using tag6 = add_tag< 6, SUB_NET>;
template <typename SUB_NET> using tag7 = add_tag< 7, SUB_NET>;
template <typename SUB_NET> using tag8 = add_tag< 8, SUB_NET>;
template <typename SUB_NET> using tag9 = add_tag< 9, SUB_NET>;
template <typename SUB_NET> using tag10 = add_tag<10, SUB_NET>;
template <typename SUB_NET> using skip1 = add_skip< tag1, SUB_NET>;
template <typename SUB_NET> using skip2 = add_skip< tag2, SUB_NET>;
template <typename SUB_NET> using skip3 = add_skip< tag3, SUB_NET>;
template <typename SUB_NET> using skip4 = add_skip< tag4, SUB_NET>;
template <typename SUB_NET> using skip5 = add_skip< tag5, SUB_NET>;
template <typename SUB_NET> using skip6 = add_skip< tag6, SUB_NET>;
template <typename SUB_NET> using skip7 = add_skip< tag7, SUB_NET>;
template <typename SUB_NET> using skip8 = add_skip< tag8, SUB_NET>;
template <typename SUB_NET> using skip9 = add_skip< tag9, SUB_NET>;
template <typename SUB_NET> using skip10 = add_skip<tag10, SUB_NET>;
// ----------------------------------------------------------------------------------------
namespace timpl
{
inline void fill_with_gaussian_random_numbers (
tensor& t,
dlib::rand& rnd,
double sigma = 1
)
{
float* data = t.host();
for (size_t i = 0; i < t.size(); ++i)
data[i] = rnd.get_random_gaussian()*sigma;
}
class test_layer_sub_net
{
public:
test_layer_sub_net (
dlib::rand& rnd_
) : rnd(rnd_)
{
// Output and gradient_input have to have the same dimensions in each
// layer.
const long num_samples = rnd.get_random_32bit_number()%4+3;
const long nr = rnd.get_random_32bit_number()%4+2;
const long nc = rnd.get_random_32bit_number()%4+2;
const long k = rnd.get_random_32bit_number()%4+2;
output.set_size(num_samples, nr, nc, k);
gradient_input.set_size(num_samples, nr, nc, k);
// Use a non-zero initial gradient to make sure the layers add to it
// rather than assign and blow away the initial value.
fill_with_gaussian_random_numbers(gradient_input, rnd, 0.01);
fill_with_gaussian_random_numbers(output, rnd);
}
const tensor& get_output() const { return output; }
const test_layer_sub_net& sub_net() const { init_sub(); return *sub; }
tensor& get_gradient_input() { return gradient_input; }
test_layer_sub_net& sub_net() { init_sub(); return *sub; }
unsigned long count_outputs() const
{
if (sub)
return sub->count_outputs() + output.size();
else
return output.size();
}
float& get_output_element(unsigned long i)
{
if (i < output.size())
return output.host()[i];
else
return sub_net().get_output_element(i-output.size());
}
float get_gradient_input_element(unsigned long i) const
{
if (i < gradient_input.size())
return gradient_input.host()[i];
else
return sub_net().get_gradient_input_element(i-gradient_input.size());
}
private:
// We lazily initialize sub-layers as needed when someone tries to call
// sub_net()
void init_sub() const
{
if (!sub)
sub.reset(new test_layer_sub_net(rnd));
}
dlib::rand& rnd;
mutable std::unique_ptr<test_layer_sub_net> sub;
resizable_tensor output;
resizable_tensor gradient_input;
};
inline void print_tensor(
const tensor& a
)
{
auto data = a.host();
for (size_t i = 0; i < a.size(); ++i)
std::cout << data[i] << " ";
std::cout << std::endl;
}
}
template <
typename layer_details_type
>
void test_layer (
layer_details_type l
)
{
const float base_eps = 0.01;
using namespace timpl;
// Do some setup
dlib::rand rnd;
test_layer_sub_net sub(rnd);
resizable_tensor output, out2, out3;
// Run setup() and forward() as well to make sure any calls to sub_net() have
// happened before we start assuming we know how many data elements there are
// (since we do a lazy layer creation thing based on calls to sub_net() inside
// test_layer_sub_net).
l.setup(sub);
l.forward(sub, output);
resizable_tensor input_grad;
input_grad.copy_size(output);
std::cout << "output.num_samples(): "<< output.num_samples() << std::endl;
fill_with_gaussian_random_numbers(input_grad, rnd);
// The f() we are computing gradients of is this thing. Its value at the
// current parameter and data values is:
std::cout << "f(data,params): " << dot(output, input_grad) << std::endl;
// We are going to save a copy of the sub.get_gradient_input() data before we do
// backpropagation since the backward() function is supposed to *add* to the
// gradients rather than overwrite them. We will use this saved data to check if
// that is the case.
const unsigned long num_data_inputs = sub.count_outputs();
std::vector<float> initial_gradient_input(num_data_inputs);
for (unsigned long i = 0; i < num_data_inputs; ++i)
initial_gradient_input[i] = sub.get_gradient_input_element(i);
// Now tell the layer to compute all the gradients. In the rest of this function
// we will just be checking that these gradients were computed correctly by
// comparing them to a central differences approximation.
resizable_tensor params_grad, random_noise;
params_grad.copy_size(l.get_layer_params());
random_noise.copy_size(l.get_layer_params());
randomize_parameters(random_noise, 5, rnd);
params_grad = random_noise;
l.backward(input_grad, sub, params_grad);
running_stats<double> rs_param, rs_data;
// ==================================================================
// first validate the way the parameter gradients are computed
for (size_t i = 0; i < params_grad.size(); ++i)
{
layer_details_type l1(l);
float eps = l1.get_layer_params().host()[i]*base_eps;
if (eps == 0)
eps = base_eps;
const float oldval = l1.get_layer_params().host()[i];
l1.get_layer_params().host()[i] = oldval+eps;
l1.forward(sub, out2);
l1.get_layer_params().host()[i] = oldval-eps;
l1.forward(sub, out3);
// Compute a reference derivative via a central differences approximation and
// compare it to the one output by the layer and make sure they match.
double reference_derivative = (dot(out2,input_grad)-dot(out3, input_grad))/(2*eps);
double output_derivative = params_grad.host()[i]-random_noise.host()[i];
double relative_error = (reference_derivative - output_derivative)/(reference_derivative + 1e-100);
if (std::abs(relative_error) > 0.01)
{
using namespace std;
cout << "PARAM ERROR: "<< relative_error << endl;
cout << " reference_derivative: " << reference_derivative << endl;
cout << " output_derivative: " << output_derivative << endl;
}
rs_param.add(std::abs(relative_error));
}
// ==================================================================
// now validate the data gradients
for (unsigned long i = 0; i < num_data_inputs; ++i)
{
const float oldval = sub.get_output_element(i);
float eps = oldval*base_eps;
if (eps == 0)
eps = base_eps;
sub.get_output_element(i) = oldval+eps;
l.forward(sub, out2);
sub.get_output_element(i) = oldval-eps;
l.forward(sub, out3);
// Compute a reference derivative via a central differences approximation and
// compare it to the one output by the layer and make sure they match.
double reference_derivative = (dot(out2,input_grad)-dot(out3, input_grad))/(2*eps);
double output_derivative = sub.get_gradient_input_element(i)-initial_gradient_input[i];
double relative_error = (reference_derivative - output_derivative)/(reference_derivative + 1e-100);
if (std::abs(relative_error) > 0.01)
{
using namespace std;
cout << "DATA ERROR: "<< relative_error << endl;
cout << " reference_derivative: " << reference_derivative << endl;
cout << " output_derivative: " << output_derivative << endl;
}
rs_data.add(std::abs(relative_error));
}
using namespace std;
if (rs_param.current_n() > 1)
{
cout << "rs_param.mean(): " << rs_param.mean() << endl;
cout << "rs_param.stddev(): " << rs_param.stddev() << endl;
cout << "rs_param.max(): " << rs_param.max() << endl;
}
if (rs_data.current_n() > 1)
{
cout << "rs_data.mean(): " << rs_data.mean() << endl;
cout << "rs_data.stddev(): " << rs_data.stddev() << endl;
cout << "rs_data.max(): " << rs_data.max() << endl;
}
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename net_type,
typename solver_type = sgd
>
class dnn_trainer
{
public:
static_assert(is_loss_layer_type<net_type>::value,
"The last layer in a network must be a loss layer.");
typedef typename net_type::label_type label_type;
typedef typename net_type::input_type input_type;
dnn_trainer()
{}
explicit dnn_trainer(const net_type& net_) : net(net_) {}
dnn_trainer(
const net_type& net_,
const solver_type& solver_
) : net(net_), solvers(solver_) {}
const net_type& get_net (
) const { return net; }
void set_net (
const net_type& net_
)
{
net = net_;
}
void set_solver (
const solver_type& solver_
)
{
solvers = solver_;
}
const sstack<solver_type,net_type::num_layers>& get_solvers (
) const { return solvers; }
sstack<solver_type,net_type::num_layers>& get_solvers (
) { return solvers; }
const net_type& train (
const std::vector<input_type>& data,
const std::vector<label_type>& labels
)
/*!
requires
- data.size() == labels.size()
!*/
{
const int batch_size = 11;
for (int iter = 0; iter < 300; ++iter)
{
for (size_t i = 0; i < data.size(); i+=batch_size)
{
// TODO, move the contents of update() here and do the alternating tensor
// loading thing to hide GPU transfer latency.
std::cout << "loss: "<<net.update(data.begin()+i,
data.begin()+std::min(i+batch_size, data.size()),
labels.begin()+i,
solvers) << std::endl;
}
}
return net;
}
const net_type& train (
const std::vector<input_type>& data
)
/*!
ensures
- trains an auto-encoder
!*/
{
const bool has_unsupervised_loss = std::is_same<no_label_type, label_type>::value;
static_assert(has_unsupervised_loss,
"You can only call this version of train() when using an unsupervised loss.");
const int batch_size = 10;
for (int iter = 0; iter < 300; ++iter)
{
for (size_t i = 0; i < data.size(); i+=batch_size)
{
// TODO, move the contents of update() here and do the alternating tensor
// loading thing to hide GPU transfer latency.
std::cout << "loss: "<<net.update(data.begin()+i,
data.begin()+std::min(i+batch_size, data.size()),
solvers) << std::endl;
}
}
return net;
}
private:
net_type net;
sstack<solver_type,net_type::num_layers> solvers;
};
// TODO, make dnn_trainer serializable.
// ----------------------------------------------------------------------------------------
}
#endif // #define DLIB_DNn_CORE_H_
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_DNn_CORE_ABSTRACT_H_
#ifdef DLIB_DNn_CORE_ABSTRACT_H_
#include "tensor_abstract.h"
#include "solvers_abstract.h"
#include <memory>
#include <type_traits>
#include <dlib/statistics.h>
#include <dlib/rand.h>
namespace dlib
{
// ----------------------------------------------------------------------------------------
void randomize_parameters (
tensor& params,
unsigned long num_inputs_and_outputs,
dlib::rand& rnd
);
/*!
ensures
- This function assigns random values into params based on the given random
number generator. In particular, it uses the parameter initialization method
of formula 16 from the paper "Understanding the difficulty of training deep
feedforward neural networks" by Xavier Glorot and Yoshua Bengio.
- It is assumed that the total number of inputs and outputs from the layer is
num_inputs_and_outputs. That is, you should set num_inputs_and_outputs to
the sum of the dimensionalities of the vectors going into and out of the
layer that uses params as its parameters.
!*/
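// For instance, an illustrative sketch: a fully connected layer with 10
// inputs and 5 outputs would initialize its parameter tensor like this:
//
//     dlib::rand rnd;
//     randomize_parameters(params, 10+5, rnd);
//
// after which every element of params is drawn uniformly from
// [-sqrt(6/15), +sqrt(6/15)].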
// ----------------------------------------------------------------------------------------
template <
typename T,
size_t N
>
class sstack
{
/*!
REQUIREMENTS ON T
- T is default and copy constructable.
REQUIREMENTS ON N
- N > 0
WHAT THIS OBJECT REPRESENTS
This is a basic stack of T objects. It holds N of the objects and is
entirely allocated on the stack.
!*/
public:
typedef T value_type;
const static size_t num_elements = N;
sstack(
);
/*!
ensures
- #size() == N
- All elements of this stack are default constructed.
!*/
sstack(
const T& item
);
/*!
ensures
- #size() == N
- Initializes all N elements in this stack with the given item.
E.g. top()==item, pop().top()==item, pop().pop().top()==item, etc.
!*/
const T& top(
) const;
/*!
ensures
- returns the top element of the stack.
!*/
T& top(
);
/*!
ensures
- returns the top element of the stack.
!*/
size_t size(
) const;
/*!
ensures
- returns the number of elements in this stack. In particular, the
number returned is always N.
!*/
const sstack<T,N-1>& pop(
) const;
/*!
requires
- size() > 1
ensures
- returns a reference to the sub-stack S such that:
- S.size() == size()-1.
- S.top() is the next element in the stack.
!*/
sstack<T,N-1>& pop(
);
/*!
requires
- size() > 1
ensures
- returns a reference to the sub-stack S such that:
- S.size() == size()-1.
- S.top() is the next element in the stack.
!*/
};
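// A quick usage sketch: an sstack<T,N> holds N elements and pop() peels off
// the top element, returning the remaining sstack<T,N-1>. For example:
//
//     sstack<int,3> s(42);          // every element starts out as 42
//     s.top() = 7;                  // modifies only the first element
//     int a = s.pop().top();        // a == 42, the second element
//     int b = s.pop().pop().top();  // b == 42, the third element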
// ----------------------------------------------------------------------------------------
template <
typename LAYER_DETAILS,
typename SUB_NET
>
class add_layer
{
/*!
REQUIREMENTS ON LAYER_DETAILS
- Must be a type that implements the EXAMPLE_LAYER_ interface defined in
layers_abstract.h
REQUIREMENTS ON SUB_NET
- One of the following must be true:
- SUB_NET implements the input interface (TODO clarify) defined in
input_abstract.h.
- SUB_NET is an add_layer object.
- SUB_NET is an add_tag object.
- SUB_NET is an add_skip object.
WHAT THIS OBJECT REPRESENTS
Stacks a new layer, defined by LAYER_DETAILS, on top of SUB_NET type.
!*/
public:
typedef LAYER_DETAILS layer_details_type;
typedef SUB_NET sub_net_type;
typedef typename sub_net_type::input_type input_type;
const static unsigned int sample_expansion_factor = sub_net_type::sample_expansion_factor;
// If SUB_NET is an input layer then num_layers == 1, otherwise it has the
// definition shown here:
const static size_t num_layers = sub_net_type::num_layers + 1;
add_layer(
);
add_layer(const add_layer&) = default;
add_layer(add_layer&&) = default;
add_layer& operator=(add_layer&&) = default;
add_layer& operator=(const add_layer&) = default;
// Allow copying networks from one to another as long as their corresponding
// layers can be constructed from each other.
template <typename T, typename U>
add_layer(
const add_layer<T,U>& item
);
/*!
ensures
- #layer_details() == layer_details_type(item.layer_details())
- #sub_net() == sub_net_type(item.sub_net())
!*/
template <typename ...T>
add_layer(
const LAYER_DETAILS& layer_det,
T&& ...args
);
/*!
ensures
- #layer_details() == layer_details_type(layer_det)
- #sub_net() == sub_net_type(args)
!*/
template <typename ...T>
add_layer(
LAYER_DETAILS&& layer_det,
T&& ...args
);
/*!
ensures
- #layer_details() == layer_details_type(layer_det)
- #sub_net() == sub_net_type(args)
!*/
template <typename input_iterator>
void to_tensor (
input_iterator begin,
input_iterator end,
resizable_tensor& data
) const;
/*!
requires
- [begin, end) is an iterator range over input_type objects.
ensures
- Converts the iterator range into a tensor and stores it into #data.
- #data.num_samples() == distance(begin,end)*sample_expansion_factor.
- Invokes data.async_copy_to_device() so that the data begins transferring
to the device.
- Ultimately this function just calls sub_net().sub_net()...sub_net().to_tensor(begin,end,data).
!*/
template <typename input_iterator>
const tensor& operator() (
input_iterator ibegin,
input_iterator iend
);
/*!
ensures
- runs [ibegin,iend) through the network and returns the results.
In particular, this function performs:
to_tensor(ibegin,iend,temp_tensor);
return forward(temp_tensor);
- The return value from this function is also available in #get_output().
- have_same_dimensions(#get_gradient_input(), #get_output()) == true
- All elements of #get_gradient_input() are set to 0.
!*/
const tensor& operator() (
const input_type& x
);
/*!
ensures
- runs a single x through the network and returns the output.
I.e. returns (*this)(&x, &x+1);
!*/
const tensor& forward(
const tensor& x
);
/*!
ensures
- Runs x through the network and returns the results. In particular, this
function performs the equivalent of:
sub_net().forward(x);
if (this is the first time forward() has been called) then
layer_details().setup(sub_net());
layer_details().forward(sub_net(), get_output());
- The return value from this function is also available in #get_output().
- have_same_dimensions(#get_gradient_input(), #get_output()) == true
- All elements of #get_gradient_input() are set to 0.
!*/
const tensor& get_output(
) const;
/*!
ensures
- returns the output for the last tensor that was run through the network.
If nothing has been run through the network yet then returns an empty
tensor.
!*/
tensor& get_gradient_input(
);
/*!
ensures
- returns the error gradient for this network. That is, this is the
tensor into which clients should add the gradient of the loss with
respect to get_output() before calling update().
- have_same_dimensions(#get_gradient_input(), get_output()) == true
- The first time this is called after a call to forward(), all elements
of the returned tensor are set to 0.
!*/
template <typename solver_type>
void update(const tensor& x, sstack<solver_type,num_layers>& solvers);
/*!
requires
- forward(x) was called to forward propagate x through the network.
- x.num_samples() == get_gradient_input().num_samples()
- get_gradient_input() == the gradient of the network with respect
to some loss.
ensures
- Back propagates the gradient in get_gradient_input() through this
layer and its sub network, and updates each layer's parameters using
the corresponding solver from solvers.
!*/
const sub_net_type& sub_net() const { return sub_network; }
sub_net_type& sub_net() { return sub_network; }
const layer_details_type& layer_details() const { return details; }
layer_details_type& layer_details() { return details; }
void clean(
);
};
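// A usage sketch, illustrative only: using the fc_ details and input layer
// defined later in this commit, a network with two fully connected layers
// stacked on an input layer could be declared and run on one sample via:
//
//     using net_type = add_layer<fc_, add_layer<fc_, input<matrix<float,10,1>>>>;
//     net_type net(fc_(5), fc_(20));  // 5 outputs stacked on 20 outputs
//     matrix<float,10,1> x;
//     x = 1;
//     const tensor& out = net(x);     // forward a single sample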
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
class no_label_type;
template <
typename LOSS_DETAILS,
typename SUB_NET
>
class add_loss
{
/*!
REQUIREMENTS ON LOSS_DETAILS
- Must be a type that implements a loss layer interface (e.g. one of the
loss layers added in loss.h).
- LOSS_DETAILS::sample_expansion_factor == SUB_NET::sample_expansion_factor
i.e. The loss layer and input layer must agree on the sample_expansion_factor.
REQUIREMENTS ON SUB_NET
- One of the following must be true:
- SUB_NET is an add_layer object.
- SUB_NET is an add_tag object.
- SUB_NET is an add_skip object.
WHAT THIS OBJECT REPRESENTS
Adds a loss layer, defined by LOSS_DETAILS, on top of SUB_NET.
!*/
public:
typedef LOSS_DETAILS loss_details_type;
typedef SUB_NET sub_net_type;
typedef typename sub_net_type::input_type input_type;
// Note that the loss layer doesn't count as an additional layer.
const static size_t num_layers = sub_net_type::num_layers;
const static unsigned int sample_expansion_factor = sub_net_type::sample_expansion_factor;
// If LOSS_DETAILS is an unsupervised loss then label_type==no_label_type.
// Otherwise it is defined as follows:
typedef typename LOSS_DETAILS::label_type label_type;
static_assert(sample_expansion_factor == LOSS_DETAILS::sample_expansion_factor,
"The loss layer and input layer must agree on the sample_expansion_factor.");
add_loss() = default;
add_loss(const add_loss&) = default;
add_loss(add_loss&&) = default;
add_loss& operator=(add_loss&&) = default;
add_loss& operator=(const add_loss&) = default;
template <typename T, typename U>
add_loss(
const add_loss<T,U>& item
);
/*!
ensures
- #loss_details() == loss_details_type(item.loss_details())
- #sub_net() == sub_net_type(item.sub_net())
!*/
template <typename ...T>
add_loss(
const LOSS_DETAILS& layer_det,
T&& ...args
);
/*!
ensures
- #loss_details() == loss_details_type(layer_det)
- #sub_net() == sub_net_type(args)
!*/
template <typename ...T>
add_loss(
LOSS_DETAILS&& layer_det,
T&& ...args
);
/*!
ensures
- #loss_details() == loss_details_type(layer_det)
- #sub_net() == sub_net_type(args)
!*/
template <typename ...T>
add_loss(
T ...args
);
/*!
ensures
- #loss_details() == loss_details_type()
- #sub_net() == sub_net_type(args)
!*/
template <typename input_iterator, typename output_iterator>
void operator() (
input_iterator ibegin,
input_iterator iend,
output_iterator obegin
);
/*!
requires
- obegin == iterator pointing to the start of a range of distance(ibegin,iend)
elements.
ensures
- runs [ibegin,iend) through the network and writes the output to the range at obegin.
!*/
const label_type& operator() (const input_type& x);
/*!
ensures
- runs a single x through the network and returns the output.
!*/
template <typename input_iterator, typename label_iterator>
double compute_loss (
input_iterator ibegin,
input_iterator iend,
label_iterator lbegin
);
/*!
requires
- lbegin == iterator pointing to the start of a range of
distance(ibegin,iend) labels.
ensures
- runs [ibegin,iend) through the network, compares the network's output
to the given labels, and returns the resulting loss.
!*/
template <typename input_iterator>
double compute_loss (
input_iterator ibegin,
input_iterator iend
);
/*!
ensures
- runs [ibegin,iend) through the network and returns the resulting loss.
This overload is for unsupervised losses that don't take labels.
!*/
template <typename input_iterator, typename label_iterator, typename solver_type>
double update (
input_iterator ibegin,
input_iterator iend,
label_iterator lbegin,
sstack<solver_type,num_layers>& solvers
);
/*!
requires
- lbegin == iterator pointing to the start of a range of
distance(ibegin,iend) labels.
ensures
- runs [ibegin,iend) through the network, computes the loss against the
given labels, back propagates the resulting gradient, updates the
network parameters using the given solvers, and returns the loss.
!*/
template <typename input_iterator, typename solver_type>
double update (
input_iterator ibegin,
input_iterator iend,
sstack<solver_type,num_layers>& solvers
);
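/*!
ensures
- same as the update() above, except it is used with unsupervised losses
that don't take labels.
!*/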
const sub_net_type& sub_net() const { return sub; }
sub_net_type& sub_net() { return sub; }
const loss_details_type& loss_details() const { return loss; }
loss_details_type& loss_details() { return loss; }
void clean (
);
/*!
ensures
- Causes the network to forget about everything but its parameters.
That is, for each layer we will have:
- get_output().num_samples() == 0
- get_gradient_input().num_samples() == 0
However, running new input data through this network will still produce
the same output it would have produced regardless of any calls to clean().
Finally, the purpose of clean() is to compact the network object prior to
saving it to disk so that it takes up less space and the IO is quicker.
!*/
};
template <typename T, typename U>
struct is_loss_layer_type<add_loss<T,U>> : std::true_type {};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
unsigned int i,
typename net_type
>
auto& layer (
net_type& n
);
/*!
requires
- net_type is an object of type add_layer, add_loss, add_skip, or add_tag.
ensures
- This function chains together i calls to n.sub_net() and returns the
result. So for example:
- if (i == 0)
- returns n
- else if (i == 1)
- returns n.sub_net()
- else if (i == 2)
- returns n.sub_net().sub_net()
- else if (i == 3)
- returns n.sub_net().sub_net().sub_net()
- else
- etc.
!*/
template <
template<typename> class Match,
typename net_type
>
auto& layer (
net_type& n
);
/*!
requires
- net_type is an object of type add_layer, add_loss, add_skip, or add_tag.
ensures
- returns the first layer in n that is of type Match. E.g. if net_type is
fc<relu<fc<input<sample_type>>>> then calling layer<relu>(n) would return
layer<1>(n), that is, a reference to the relu layer.
!*/
template <
template<typename> class Match,
unsigned int i,
typename net_type
>
auto& layer (
net_type& n
);
/*!
requires
- net_type is an object of type add_layer, add_loss, add_skip, or add_tag.
ensures
- returns layer<i>(layer<Match>(n))
!*/
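// For example, an illustrative sketch: if net_type were
//     fc<tag1<fc<input<sample_type>>>>
// then layer<tag1>(n) would return the tag1 layer, layer<tag1,1>(n) the fc
// layer just beneath it, and layer<tag1,2>(n) the input layer.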
// ----------------------------------------------------------------------------------------
template <
unsigned long ID,
typename SUB_NET
>
class add_tag
{
/*!
WHAT THIS OBJECT REPRESENTS
This object draws its inputs from sub_net() and performs the identity
transform. This means it is a no-op and its presence does not change
the behavior of the network. It exists solely to be used by add_skip
to reference a particular part of a network.
!*/
};
template <typename SUB_NET> using tag1 = add_tag< 1, SUB_NET>;
template <typename SUB_NET> using tag2 = add_tag< 2, SUB_NET>;
template <typename SUB_NET> using tag3 = add_tag< 3, SUB_NET>;
template <typename SUB_NET> using tag4 = add_tag< 4, SUB_NET>;
template <typename SUB_NET> using tag5 = add_tag< 5, SUB_NET>;
template <typename SUB_NET> using tag6 = add_tag< 6, SUB_NET>;
template <typename SUB_NET> using tag7 = add_tag< 7, SUB_NET>;
template <typename SUB_NET> using tag8 = add_tag< 8, SUB_NET>;
template <typename SUB_NET> using tag9 = add_tag< 9, SUB_NET>;
template <typename SUB_NET> using tag10 = add_tag<10, SUB_NET>;
// ----------------------------------------------------------------------------------------
template <
template<typename> class TAG_TYPE,
typename SUB_NET
>
class add_skip
{
/*!
WHAT THIS OBJECT REPRESENTS
This object draws its inputs from layer<TAG_TYPE>(sub_net())
and performs the identity transform.
!*/
};
template <typename SUB_NET> using skip1 = add_skip< tag1, SUB_NET>;
template <typename SUB_NET> using skip2 = add_skip< tag2, SUB_NET>;
template <typename SUB_NET> using skip3 = add_skip< tag3, SUB_NET>;
template <typename SUB_NET> using skip4 = add_skip< tag4, SUB_NET>;
template <typename SUB_NET> using skip5 = add_skip< tag5, SUB_NET>;
template <typename SUB_NET> using skip6 = add_skip< tag6, SUB_NET>;
template <typename SUB_NET> using skip7 = add_skip< tag7, SUB_NET>;
template <typename SUB_NET> using skip8 = add_skip< tag8, SUB_NET>;
template <typename SUB_NET> using skip9 = add_skip< tag9, SUB_NET>;
template <typename SUB_NET> using skip10 = add_skip<tag10, SUB_NET>;
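// Tags and skips compose like this, an illustrative sketch: here the skip1
// layer ignores the fc layer directly below it and instead draws its output
// from the layer marked with tag1:
//
//     using net_type = skip1<fc<tag1<fc<input<sample_type>>>>>;
//
// so a net_type object's get_output() returns the output of the inner fc
// layer, i.e. the tensor recorded at tag1.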
// ----------------------------------------------------------------------------------------
template <
typename layer_details_type
>
void test_layer (
layer_details_type l
);
/*!
requires
- l implements the EXAMPLE_LAYER_ interface defined in layers_abstract.h
ensures
- tests l for compliance against the EXAMPLE_LAYER_ interface spec.
!*/
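// For example, to numerically check the gradients of the fc_ layer defined
// later in this commit you could write:
//
//     test_layer(fc_(5));
//
// which compares each analytic gradient against a central differences
// approximation and prints the relative errors.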
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename net_type,
typename solver_type = sgd
>
class dnn_trainer
{
/*!
REQUIREMENTS ON net_type
- net_type is an add_loss object.
REQUIREMENTS ON solver_type
- solver_type is an implementation of the EXAMPLE_SOLVER interface defined
in solvers_abstract.h
WHAT THIS OBJECT REPRESENTS
This object is a basic tool for training a deep neural network: it
feeds mini-batches of data through the network and uses the given
solvers to update the network parameters so as to reduce the loss.
!*/
public:
typedef typename net_type::label_type label_type;
typedef typename net_type::input_type input_type;
dnn_trainer(
);
explicit dnn_trainer(
const net_type& net
);
dnn_trainer(
const net_type& net,
const solver_type& solver
);
const net_type& get_net (
) const;
void set_net (
const net_type& net
);
void set_solver (
const solver_type& solver_
);
const sstack<solver_type,net_type::num_layers>& get_solvers (
) const;
sstack<solver_type,net_type::num_layers>& get_solvers (
);
const net_type& train (
const std::vector<input_type>& data,
const std::vector<label_type>& labels
);
/*!
requires
- data.size() == labels.size()
- TODO: the net has a supervised loss layer.
!*/
const net_type& train (
const std::vector<input_type>& data
);
/*!
requires
- TODO: the net has an unsupervised loss layer.
ensures
- trains an auto-encoder
!*/
};
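// A rough training sketch, where some_loss_ stands in for one of the loss
// layer types added in loss.h:
//
//     using net_type = add_loss<some_loss_, add_layer<fc_, input<sample_type>>>;
//     std::vector<sample_type> data;
//     std::vector<net_type::label_type> labels;
//     // ... fill data and labels ...
//     dnn_trainer<net_type> trainer;
//     trainer.set_solver(sgd());
//     trainer.train(data, labels);
//     net_type net = trainer.get_net();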
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_DNn_CORE_ABSTRACT_H_
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNn_INPUT_H_
#define DLIB_DNn_INPUT_H_
#include <dlib/matrix.h>
#include <dlib/pixel.h>
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <typename T>
class input
{
public:
// sample_expansion_factor must be > 0
const static unsigned int sample_expansion_factor = 1;
typedef T input_type;
template <typename input_iterator>
void to_tensor (
input_iterator begin,
input_iterator end,
resizable_tensor& data
) const
/*!
requires
- [begin, end) is an iterator range over input_type objects.
ensures
- Converts the iterator range into a tensor and stores it into #data.
- Normally you would have #data.num_samples() == distance(begin,end) but
you can also expand the output by some integer factor so long as the loss
you use can deal with it correctly.
- #data.num_samples() == distance(begin,end)*sample_expansion_factor.
!*/
{
// Initialize data to the right size to contain the stuff in the iterator
// range, assuming (an assumption of this sketch) a contiguous
// (sample, row, column, channel) host layout with k==3 color channels.
data.set_size(std::distance(begin,end), begin->nr(), begin->nc(), 3);
float* out = data.host();
for (input_iterator i = begin; i != end; ++i)
{
matrix<rgb_pixel> temp = *i;
// now copy *i into the right part of data.
for (long r = 0; r < temp.nr(); ++r)
for (long c = 0; c < temp.nc(); ++c)
{
*out++ = temp(r,c).red; *out++ = temp(r,c).green; *out++ = temp(r,c).blue;
}
}
}
};
// ----------------------------------------------------------------------------------------
template <typename T,long NR, typename MM, typename L>
class input<matrix<T,NR,1,MM,L>>
{
public:
// TODO, maybe we should only allow T to be float? Seems kinda pointless to allow
// double. Don't forget to remove the matrix_cast if we enforce just float.
typedef matrix<T,NR,1,MM,L> input_type;
const static unsigned int sample_expansion_factor = 1;
template <typename input_iterator>
void to_tensor (
input_iterator begin,
input_iterator end,
resizable_tensor& data
) const
/*!
requires
- [begin, end) is an iterator range over input_type objects.
ensures
- converts the iterator range into a tensor and stores it into #data.
- Normally you would have #data.num_samples() == distance(begin,end) but
you can also expand the output by some integer factor so long as the loss
you use can deal with it correctly.
- #data.num_samples() == distance(begin,end)*sample_expansion_factor.
!*/
{
// initialize data to the right size to contain the stuff in the iterator range.
data.set_size(std::distance(begin,end), 1, 1, begin->size());
unsigned long idx = 0;
for (input_iterator i = begin; i != end; ++i)
{
data.set_sample(idx++, matrix_cast<float>(*i));
}
}
};
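    // A hedged usage sketch (get_training_vectors() is a hypothetical helper):
    //     std::vector<matrix<float,0,1>> samples = get_training_vectors();
    //     resizable_tensor data;
    //     input<matrix<float,0,1>> inp;
    //     inp.to_tensor(samples.begin(), samples.end(), data);
    //     // now data.num_samples() == samples.size() and each row of mat(data)
    //     // holds one sample.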
// ----------------------------------------------------------------------------------------
template <typename T>
class input2
{
public:
input2(){}
input2(const input<T>&) {}
typedef T input_type;
const static unsigned int sample_expansion_factor = 1;
template <typename input_iterator>
void to_tensor (
input_iterator begin,
input_iterator end,
resizable_tensor& data
) const
/*!
requires
- [begin, end) is an iterator range over T objects.
ensures
- converts the iterator range into a tensor and stores it into #data.
- Normally you would have #data.num_samples() == distance(begin,end) but
you can also expand the output by some integer factor so long as the loss
you use can deal with it correctly.
- #data.num_samples() == distance(begin,end)*K where K is an integer >= 1.
!*/
{
            // TODO: this is still a stub, like input<T>::to_tensor() above.  It
            // should size #data and copy each element into its sample slot.
            for (input_iterator i = begin; i != end; ++i)
            {
                matrix<rgb_pixel> temp = *i;
                // TODO: copy temp into the i-th sample of #data.
}
}
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_DNn_INPUT_H_
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNn_LAYERS_H_
#define DLIB_DNn_LAYERS_H_
#include "layers_abstract.h"
#include "tensor.h"
#include "core.h"
#include <iostream>
#include <string>
#include <dlib/rand.h>
#include <dlib/string.h>
namespace dlib
{
// ----------------------------------------------------------------------------------------
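    // con_ is a stub for a convolution layer (presumably; the name suggests it).
    // All of its member functions are currently unimplemented.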
class con_
{
public:
con_()
{}
template <typename SUB_NET>
void setup (const SUB_NET& sub)
{
// TODO
}
template <typename SUB_NET>
void forward(const SUB_NET& sub, resizable_tensor& output)
{
// TODO
}
template <typename SUB_NET>
void backward(const tensor& gradient_input, SUB_NET& sub, tensor& params_grad)
{
// TODO
}
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
private:
resizable_tensor params;
};
template <typename SUB_NET>
using con = add_layer<con_, SUB_NET>;
// ----------------------------------------------------------------------------------------
class fc_
{
public:
fc_() : num_outputs(1)
{
rnd.set_seed("fc_" + cast_to_string(num_outputs));
}
explicit fc_(unsigned long num_outputs_)
{
num_outputs = num_outputs_;
rnd.set_seed("fc_" + cast_to_string(num_outputs));
}
unsigned long get_num_outputs (
) const { return num_outputs; }
template <typename SUB_NET>
void setup (const SUB_NET& sub)
{
num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k();
params.set_size(num_inputs, num_outputs);
std::cout << "fc_::setup() " << params.size() << std::endl;
randomize_parameters(params, num_inputs+num_outputs, rnd);
}
template <typename SUB_NET>
void forward(const SUB_NET& sub, resizable_tensor& output)
{
output.set_size(sub.get_output().num_samples(), num_outputs);
output = mat(sub.get_output())*mat(params);
}
template <typename SUB_NET>
void backward(const tensor& gradient_input, SUB_NET& sub, tensor& params_grad)
{
            // Let X = sub.get_output() (one sample per row), W = params, and
            // G = gradient_input.  forward() computed OUT = X*W, so:
            //   - the parameter gradient is dL/dW = trans(X)*G
            //   - the data gradient is      dL/dX = G*trans(W)
// compute the gradient of the parameters.
params_grad += trans(mat(sub.get_output()))*mat(gradient_input);
// compute the gradient for the data
sub.get_gradient_input() += mat(gradient_input)*trans(mat(params));
}
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
private:
unsigned long num_outputs;
unsigned long num_inputs;
resizable_tensor params;
dlib::rand rnd;
};
template <typename SUB_NET>
using fc = add_layer<fc_, SUB_NET>;
// ----------------------------------------------------------------------------------------
class relu_
{
public:
relu_()
{
}
template <typename SUB_NET>
void setup (const SUB_NET& sub)
{
}
template <typename SUB_NET>
void forward(const SUB_NET& sub, resizable_tensor& output)
{
output.copy_size(sub.get_output());
output = lowerbound(mat(sub.get_output()), 0);
}
template <typename SUB_NET>
void backward(const tensor& gradient_input, SUB_NET& sub, tensor& params_grad)
{
const float* grad = gradient_input.host();
const float* in = sub.get_output().host();
float* out = sub.get_gradient_input().host();
for (unsigned long i = 0; i < sub.get_output().size(); ++i)
{
if (in[i] > 0)
out[i] += grad[i];
}
}
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
private:
resizable_tensor params;
};
template <typename SUB_NET>
using relu = add_layer<relu_, SUB_NET>;
// ----------------------------------------------------------------------------------------
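    // multiply_ scales each element of its input by a learned per-element
    // weight, i.e. it acts as a diagonal linear layer.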
class multiply_
{
public:
multiply_()
{
}
template <typename SUB_NET>
void setup (const SUB_NET& sub)
{
num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k();
params.set_size(1, num_inputs);
std::cout << "multiply_::setup() " << params.size() << std::endl;
const int num_outputs = num_inputs;
randomize_parameters(params, num_inputs+num_outputs, rnd);
}
template <typename SUB_NET>
void forward(const SUB_NET& sub, resizable_tensor& output)
{
DLIB_CASSERT( sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k() == params.size(), "");
DLIB_CASSERT( sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k() == num_inputs, "");
output.copy_size(sub.get_output());
auto indata = sub.get_output().host();
auto outdata = output.host();
auto paramdata = params.host();
            for (long i = 0; i < sub.get_output().num_samples(); ++i)
{
for (int j = 0; j < num_inputs; ++j)
{
*outdata++ = *indata++ * paramdata[j];
}
}
}
template <typename SUB_NET>
void backward(const tensor& gradient_input, SUB_NET& sub, tensor& params_grad)
{
params_grad += sum_rows(pointwise_multiply(mat(sub.get_output()),mat(gradient_input)));
for (long i = 0; i < gradient_input.num_samples(); ++i)
{
sub.get_gradient_input().add_to_sample(i,
pointwise_multiply(rowm(mat(gradient_input),i), mat(params)));
}
}
const tensor& get_layer_params() const { return params; }
tensor& get_layer_params() { return params; }
private:
int num_inputs;
resizable_tensor params;
dlib::rand rnd;
};
template <typename SUB_NET>
using multiply = add_layer<multiply_, SUB_NET>;
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_DNn_LAYERS_H_
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_DNn_LAYERS_ABSTRACT_H_
#ifdef DLIB_DNn_LAYERS_ABSTRACT_H_
#include "tensor_abstract.h"
#include "core_abstract.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
class SUB_NET
{
/*!
WHAT THIS OBJECT REPRESENTS
By "Sub net" we mean the part of the network closer to the input. Whenever
you get a SUB_NET it will always have computed its outputs and they will be
available in get_output().
!*/
public:
const tensor& get_output(
) const;
tensor& get_gradient_input(
);
const NEXT_SUB_NET& sub_net(
) const;
NEXT_SUB_NET& sub_net(
);
};
// ----------------------------------------------------------------------------------------
class EXAMPLE_LAYER_
{
/*!
WHAT THIS OBJECT REPRESENTS
Each layer in a deep neural network can be thought of as a function,
f(data,parameters), that takes in a data tensor, some parameters, and
produces an output tensor. You create an entire deep network by composing
these functions. Importantly, you are able to use a wide range of
different functions to accommodate whatever task you are trying to accomplish.
Dlib includes a number of common layer types but if you want to define your
own then you simply implement a class with the same interface as EXAMPLE_LAYER_.
!*/
public:
EXAMPLE_LAYER_(
);
/*!
ensures
- Default constructs this object. This function is not required to do
anything in particular but it is required that layer objects be default
constructable.
!*/
template <typename SUB_NET>
void setup (
const SUB_NET& sub
);
/*!
requires
- SUB_NET implements the SUB_NET interface defined at the top of this file.
ensures
- performs any necessary initial memory allocations and/or sets parameters
to their initial values prior to learning. Therefore, calling setup
destroys any previously learned parameters.
!*/
template <typename SUB_NET>
void forward(
const SUB_NET& sub,
resizable_tensor& output
);
/*!
requires
- SUB_NET implements the SUB_NET interface defined at the top of this file.
- setup() has been called.
ensures
- Runs the output of the sub-network through this layer and stores the
output into #output. In particular, forward() can use any of the outputs
in sub (e.g. sub.get_output(), sub.sub_net().get_output(), etc.) to
compute whatever it wants.
- #output.num_samples() == sub.get_output().num_samples()
!*/
template <typename SUB_NET>
void backward(
const tensor& gradient_input,
SUB_NET& sub,
tensor& params_grad
);
/*!
requires
- SUB_NET implements the SUB_NET interface defined at the top of this file.
- setup() has been called.
- gradient_input has the same dimensions as the output of forward(sub,output).
- have_same_dimensions(sub.get_gradient_input(), sub.get_output()) == true
- have_same_dimensions(params_grad, get_layer_params()) == true
ensures
- This function outputs the gradients of this layer with respect to the
input data from sub and also with respect to this layer's parameters.
These gradients are stored into #sub and #params_grad, respectively. To be
precise, the gradients are taken of a function f(sub,get_layer_params())
which is defined thusly:
- let OUT be the output of forward(sub,OUT).
- let f(sub,get_layer_params()) == dot(OUT, gradient_input)
Then we define the following gradient vectors:
- PARAMETER_GRADIENT == gradient of f(sub,get_layer_params()) with
respect to get_layer_params().
- for all valid I:
- DATA_GRADIENT_I == gradient of f(sub,get_layer_params()) with
respect to layer<I>(sub).get_output() (recall that forward() can
draw inputs from the immediate sub layer, sub.sub_net(), or
any earlier layer. So you must consider the gradients with
respect to all inputs drawn from sub)
Finally, backward() adds these gradients into the output by performing:
- params_grad += PARAMETER_GRADIENT
- for all valid I:
- layer<I>(sub).get_gradient_input() += DATA_GRADIENT_I
!*/
const tensor& get_layer_params(
) const;
/*!
ensures
- returns the parameters that define the behavior of forward().
!*/
tensor& get_layer_params(
);
/*!
ensures
- returns the parameters that define the behavior of forward().
!*/
};
// For each layer you define, always define an add_layer template so that layers can be
// easily composed. Moreover, the convention is that the layer class ends with an _
// while the add_layer template has the same name but without the trailing _.
template <typename SUB_NET>
using EXAMPLE_LAYER = add_layer<EXAMPLE_LAYER_, SUB_NET>;
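    // To make the interface above concrete, here is a hedged sketch (illustrative
    // only, not part of dlib) of a minimal layer with no parameters that simply
    // passes its input through unchanged:
    //
    //     class identity_
    //     {
    //     public:
    //         template <typename SUB_NET>
    //         void setup (const SUB_NET&) { /* nothing to allocate */ }
    //
    //         template <typename SUB_NET>
    //         void forward (const SUB_NET& sub, resizable_tensor& output)
    //         {
    //             output.copy_size(sub.get_output());
    //             output = mat(sub.get_output());
    //         }
    //
    //         template <typename SUB_NET>
    //         void backward (const tensor& gradient_input, SUB_NET& sub, tensor& /*params_grad*/)
    //         {
    //             // The layer is the identity map, so the data gradient is just
    //             // gradient_input and there is no parameter gradient to accumulate.
    //             sub.get_gradient_input() += mat(gradient_input);
    //         }
    //
    //         const tensor& get_layer_params() const { return params; }
    //         tensor& get_layer_params() { return params; }
    //     private:
    //         resizable_tensor params;  // intentionally left empty
    //     };
    //
    //     template <typename SUB_NET>
    //     using identity = add_layer<identity_, SUB_NET>;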
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
class fc_
{
/*!
WHAT THIS OBJECT REPRESENTS
This is an implementation of the EXAMPLE_LAYER_ interface defined above.
In particular, it defines a fully connected layer that takes an input
tensor and multiplies it by a weight matrix and outputs the results.
!*/
public:
fc_(
);
/*!
ensures
- #get_num_outputs() == 1
!*/
explicit fc_(
unsigned long num_outputs
);
/*!
ensures
- #get_num_outputs() == num_outputs
!*/
unsigned long get_num_outputs (
) const;
/*!
ensures
- This layer outputs column vectors that contain get_num_outputs()
elements. That is, the output tensor T from forward() will be such that:
- T.num_samples() == however many samples were given to forward().
- T.nr() == get_num_outputs()
- The rest of the dimensions of T will be 1.
!*/
template <typename SUB_NET> void setup (const SUB_NET& sub);
template <typename SUB_NET> void forward(const SUB_NET& sub, resizable_tensor& output);
template <typename SUB_NET> void backward(const tensor& gradient_input, SUB_NET& sub, tensor& params_grad);
const tensor& get_layer_params() const;
tensor& get_layer_params();
/*!
These functions are implemented as described in the EXAMPLE_LAYER_ interface.
!*/
};
template <typename SUB_NET>
using fc = add_layer<fc_, SUB_NET>;
// ----------------------------------------------------------------------------------------
class relu_
{
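        /*!
            WHAT THIS OBJECT REPRESENTS
                This is an implementation of the EXAMPLE_LAYER_ interface defined above.
                In particular, it defines a rectified linear layer.  It has no
                parameters and its output is max(0,x) applied to each element x of
                its input tensor.
        !*/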
public:
relu_(
);
template <typename SUB_NET> void setup (const SUB_NET& sub);
template <typename SUB_NET> void forward(const SUB_NET& sub, resizable_tensor& output);
template <typename SUB_NET> void backward(const tensor& gradient_input, SUB_NET& sub, tensor& params_grad);
const tensor& get_layer_params() const;
tensor& get_layer_params();
/*!
These functions are implemented as described in the EXAMPLE_LAYER_ interface.
!*/
};
template <typename SUB_NET>
using relu = add_layer<relu_, SUB_NET>;
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_DNn_LAYERS_ABSTRACT_H_
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNn_LOSS_H_
#define DLIB_DNn_LOSS_H_
#include "core.h"
#include <dlib/matrix.h>
namespace dlib
{
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
class loss_binary_hinge_
{
public:
const static unsigned int sample_expansion_factor = 1;
typedef double label_type;
// Implementing to_label() is optional. If you don't do it then it just means the
// automatic operator() mapping from tensors to outputs is missing from the net object.
template <
typename SUB_TYPE,
typename label_iterator
>
void to_label (
const SUB_TYPE& sub,
label_iterator iter
) const
/*!
requires
            - SUB_TYPE implements the SUB_NET interface defined at the top of layers_abstract.h.
- sub.get_output().num_samples() must be a multiple of sample_expansion_factor.
- iter == an iterator pointing to the beginning of a range of
sub.get_output().num_samples()/sample_expansion_factor elements. In
particular, they must be label_type elements.
!*/
{
const tensor& output_tensor = sub.get_output();
DLIB_CASSERT(output_tensor.nr() == 1 &&
output_tensor.nc() == 1 &&
output_tensor.k() == 1,"");
DLIB_CASSERT(output_tensor.num_samples()%sample_expansion_factor == 0,"");
const float* out_data = output_tensor.host();
            for (long i = 0; i < output_tensor.num_samples(); ++i)
{
*iter++ = out_data[i];
}
}
template <
typename label_iterator,
typename SUB_NET
>
double compute_loss (
const tensor& input_tensor,
label_iterator truth, // TODO, this parameter is optional.
SUB_NET& sub
) const
/*!
requires
- SUB_NET implements the SUB_NET interface defined at the top of layers_abstract.h.
- input_tensor was given as input to the network sub and the outputs are now
visible in sub.get_output(), sub.sub_net().get_output(), etc.
- input_tensor.num_samples() must be a multiple of sample_expansion_factor.
            - input_tensor.num_samples() == sub.get_output().num_samples() == sub.get_gradient_input().num_samples()
- truth == an iterator pointing to the beginning of a range of
input_tensor.num_samples()/sample_expansion_factor elements. In particular,
they must be label_type elements.
- sub.get_gradient_input() has the same dimensions as sub.get_output().
- for all valid i:
- *(truth+i/sample_expansion_factor) is the label of the ith sample in
sub.get_output().
ensures
- #sub.get_gradient_input() == the gradient of the loss with respect to
sub.get_output().
!*/
{
const tensor& output_tensor = sub.get_output();
tensor& grad = sub.get_gradient_input();
// TODO, throw an exception instead of asserting, probably...
DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples(),"");
DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples(),"");
DLIB_CASSERT(output_tensor.nr() == 1 &&
output_tensor.nc() == 1 &&
output_tensor.k() == 1,"");
// The loss we output is the average loss over the mini-batch.
const double scale = 1.0/output_tensor.num_samples();
double loss = 0;
const float* out_data = output_tensor.host();
float* g = grad.host();
            for (long i = 0; i < output_tensor.num_samples(); ++i)
{
const float y = *truth++;
const float temp = 1-y*out_data[i];
if (temp > 0)
{
loss += scale*temp;
g[i] = -scale*y;
}
else
{
g[i] = 0;
}
}
return loss;
}
};
// ----------------------------------------------------------------------------------------
template <typename SUB_NET>
using loss_binary_hinge = add_loss<loss_binary_hinge_, SUB_NET>;
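    // Note that, as with any binary hinge loss, the labels supplied during
    // training are expected to be +1 or -1.  The per-sample loss is
    // max(0, 1 - y*f(x)) and compute_loss() returns its average over the
    // mini-batch.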
// ----------------------------------------------------------------------------------------
class loss_no_label_
{
public:
//typedef int label_type;
const static unsigned int sample_expansion_factor = 1;
template <
typename SUB_NET
>
double compute_loss (
const tensor& input_tensor,
SUB_NET& sub
) const
        /*!
            requires
                - SUB_NET implements the SUB_NET interface defined at the top of layers_abstract.h.
                - input_tensor was given as input to the network sub and the outputs are now
                  visible in sub.get_output(), sub.sub_net().get_output(), etc.
                - input_tensor.num_samples() must be a multiple of sample_expansion_factor.
                - input_tensor.num_samples() == sub.get_output().num_samples()
                - sub.get_gradient_input() has the same dimensions as sub.get_output().
            ensures
                - returns the loss.  Note that this is currently a placeholder
                  implementation: it always returns 0 and leaves
                  #sub.get_gradient_input() unchanged.
        !*/
{
return 0;
}
};
// ----------------------------------------------------------------------------------------
template <typename SUB_NET>
using loss_no_label = add_loss<loss_no_label_, SUB_NET>;
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_DNn_LOSS_H_
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNn_SOLVERS_H_
#define DLIB_DNn_SOLVERS_H_
#include "tensor.h"
#include <iostream>
namespace dlib
{
/*
class EXAMPLE_SOLVER
{
};
*/
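    // sgd implements stochastic gradient descent with classical momentum and
    // weight decay.  The exact update equations are documented on operator()
    // below.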
struct sgd
{
matrix<float> v;
float weight_decay;
float eps;
float momentum;
sgd(double eps_ = 0.001)
{
weight_decay = 0.0005;
eps = eps_;
momentum = 0.9;
}
template <typename layer_type>
void operator() (layer_type& l, const tensor& params_grad)
        /*!
            requires
                - l.get_layer_params().size() != 0
                - l.get_layer_params() and params_grad have the same dimensions.
            ensures
                - performs a momentum SGD step on l's parameters W.  That is:
                    - v = momentum*v - eps*(weight_decay*W + params_grad)
                    - W = W + v
        !*/
{
if (v.size() != 0)
v = momentum*v - weight_decay*eps*mat(l.get_layer_params()) - eps*mat(params_grad);
else
v = - weight_decay*eps*mat(l.get_layer_params()) - eps*mat(params_grad);
l.get_layer_params() += v;
}
};
}
#endif // DLIB_DNn_SOLVERS_H_
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNn_TENSOR_H_
#define DLIB_DNn_TENSOR_H_
#include <memory>
#include <cstring>
#include <dlib/matrix.h>
namespace dlib
{
// ----------------------------------------------------------------------------------------
class gpu_data
{
/*!
CONVENTION
- if (size() != 0) then
- data_host == a pointer to size() floats in CPU memory.
- if (data_device) then
- data_device == a pointer to size() floats in device memory.
            - We use the host_current and device_current bools to keep track of
              which copy of the data (or both) is most current.  E.g. if the CPU
              has modified the tensor and it hasn't been copied to the device yet
              then host_current==true and device_current==false.
!*/
public:
gpu_data(
) : data_size(0), host_current(true), device_current(false)
{
}
// Not copyable
gpu_data(const gpu_data&) = delete;
gpu_data& operator=(const gpu_data&) = delete;
// but is movable
gpu_data(gpu_data&&) = default;
gpu_data& operator=(gpu_data&&) = default;
void set_size(size_t new_size)
{
if (new_size == 0)
{
data_size = 0;
host_current = true;
device_current = false;
data_host.reset();
data_device.reset();
}
else if (new_size != data_size)
{
data_size = new_size;
host_current = true;
device_current = false;
data_host.reset(new float[new_size]);
data_device.reset();
}
}
void async_copy_to_device()
{
// TODO
}
void async_copy_to_host()
{
// TODO
}
const float* host() const
{
copy_to_host();
return data_host.get();
}
float* host()
{
copy_to_host();
device_current = false;
return data_host.get();
}
const float* device() const
{
copy_to_device();
return data_device.get();
}
float* device()
{
copy_to_device();
host_current = false;
return data_device.get();
}
size_t size() const { return data_size; }
private:
void copy_to_device() const
{
if (!device_current)
{
// TODO, cudamemcpy()
device_current = true;
}
}
void copy_to_host() const
{
if (!host_current)
{
// TODO, cudamemcpy()
host_current = true;
}
}
size_t data_size;
mutable bool host_current;
mutable bool device_current;
std::unique_ptr<float[]> data_host;
std::unique_ptr<float[]> data_device;
};
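    // A hedged usage sketch of the lazy synchronization protocol (the device
    // copies are still TODO above, so today these calls only touch host memory):
    //     gpu_data g;
    //     g.set_size(100);
    //     g.host()[0] = 3.14f;          // non-const host() marks the device copy stale
    //     const float* d = g.device();  // would trigger a host->device copy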
// ----------------------------------------------------------------------------------------
class tensor
{
        /*!
            WHAT THIS OBJECT REPRESENTS
                This object represents a 4-dimensional array of float values,
                conceptually organized as num_samples() samples, each with nr()
                rows, nc() columns, and k() channels.  The underlying memory is
                managed by a gpu_data object and so may reside on the host, the
                device, or both.
        !*/
public:
tensor (
) :
m_n(0), m_nr(0), m_nc(0), m_k(0)
{
}
inline virtual ~tensor() = 0;
long num_samples() const { return m_n; }
long nr() const { return m_nr; }
long nc() const { return m_nc; }
long k() const { return m_k; }
size_t size() const { return data.size(); }
void async_copy_to_host()
{
data.async_copy_to_host();
}
void async_copy_to_device()
{
data.async_copy_to_device();
}
/*!
ensures
- begin asynchronously copying this tensor to the GPU.
NOTE that the "get device pointer" routine in this class
will have to do some kind of synchronization that ensures
the copy is finished.
!*/
const float* host() const { return data.host(); }
float* host() { return data.host(); }
const float* device() const { return data.device(); }
float* device() { return data.device(); }
        tensor& operator= (float val)
        {
            // TODO, do on the device if that's where the memory is living right now.
            auto d = data.host();
            for (size_t i = 0; i < data.size(); ++i)
                d[i] = val;
            return *this;
        }
template <typename EXP>
tensor& operator= (const matrix_exp<EXP>& item)
{
DLIB_CASSERT(num_samples() == item.nr() &&
nr()*nc()*k() == item.nc(),"");
static_assert((is_same_type<float, typename EXP::type>::value == true),
"To assign a matrix to a tensor the matrix must contain float values");
set_ptrm(data.host(), m_n, m_nr*m_nc*m_k) = item;
return *this;
}
template <typename EXP>
tensor& operator+= (const matrix_exp<EXP>& item)
{
DLIB_CASSERT(num_samples() == item.nr() &&
nr()*nc()*k() == item.nc(),"");
static_assert((is_same_type<float, typename EXP::type>::value == true),
"To assign a matrix to a tensor the matrix must contain float values");
set_ptrm(data.host(), m_n, m_nr*m_nc*m_k) += item;
return *this;
}
template <typename EXP>
tensor& operator-= (const matrix_exp<EXP>& item)
{
DLIB_CASSERT(num_samples() == item.nr() &&
nr()*nc()*k() == item.nc(),"");
static_assert((is_same_type<float, typename EXP::type>::value == true),
"To assign a matrix to a tensor the matrix must contain float values");
set_ptrm(data.host(), m_n, m_nr*m_nc*m_k) -= item;
return *this;
}
template <typename EXP>
void set_sample (
unsigned long idx,
const matrix_exp<EXP>& item
)
{
DLIB_CASSERT(idx < num_samples(), "");
DLIB_CASSERT(item.size() == nr()*nc()*k(), "");
static_assert((is_same_type<float, typename EXP::type>::value == true),
"To assign a matrix to a tensor the matrix must contain float values");
set_ptrm(data.host()+idx*item.size(), item.nr(), item.nc()) = item;
}
template <typename EXP>
void add_to_sample (
unsigned long idx,
const matrix_exp<EXP>& item
)
{
DLIB_CASSERT(idx < num_samples(), "");
DLIB_CASSERT(item.size() == nr()*nc()*k(), "");
static_assert((is_same_type<float, typename EXP::type>::value == true),
"To assign a matrix to a tensor the matrix must contain float values");
set_ptrm(data.host()+idx*item.size(), item.nr(), item.nc()) += item;
}
protected:
tensor& operator= (const tensor& item)
{
m_n = item.m_n;
m_nr = item.m_nr;
m_nc = item.m_nc;
m_k = item.m_k;
data.set_size(item.data.size());
std::memcpy(data.host(), item.data.host(), data.size()*sizeof(float));
return *this;
}
tensor(
const tensor& item
)
{
*this = item;
}
tensor(tensor&& item) = default;
tensor& operator=(tensor&& item) = default;
long m_n;
long m_nr;
long m_nc;
long m_k;
gpu_data data;
};
tensor::~tensor()
{
}
// ----------------------------------------------------------------------------------------
    inline const matrix_op<op_pointer_to_mat<float> > mat (
const tensor& t,
long nr,
long nc
)
{
DLIB_ASSERT(nr > 0 && nc > 0 ,
"\tconst matrix_exp mat(tensor, nr, nc)"
<< "\n\t nr and nc must be bigger than 0"
<< "\n\t nr: " << nr
<< "\n\t nc: " << nc
);
DLIB_ASSERT(nr*nc == t.size() ,
"\tconst matrix_exp mat(tensor, nr, nc)"
<< "\n\t The sizes don't match up."
<< "\n\t nr*nc: " << nr*nc
<< "\n\t t.size(): " << t.size()
);
typedef op_pointer_to_mat<float> op;
return matrix_op<op>(op(t.host(),nr,nc));
}
    inline const matrix_op<op_pointer_to_mat<float> > mat (
const tensor& t
)
{
DLIB_ASSERT(t.size() != 0,
"\tconst matrix_exp mat(tensor)"
<< "\n\t The tensor can't be empty."
);
return mat(t, t.num_samples(), t.size()/t.num_samples());
}
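    // E.g. for a tensor t with t.num_samples()==5 and t.nr()*t.nc()*t.k()==3,
    // mat(t) yields a 5x3 matrix view that aliases t's host memory (no copy is
    // made).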
// ----------------------------------------------------------------------------------------
inline bool have_same_dimensions (
const tensor& a,
const tensor& b
)
{
return a.num_samples() == b.num_samples() &&
a.nr() == b.nr() &&
a.nc() == b.nc() &&
a.k() == b.k();
}
// ----------------------------------------------------------------------------------------
class resizable_tensor : public tensor
{
public:
resizable_tensor(
)
{}
explicit resizable_tensor(
long n_, long nr_ = 1, long nc_ = 1, long k_ = 1
)
{
set_size(n_,nr_,nc_,k_);
}
resizable_tensor(const resizable_tensor&) = default;
resizable_tensor(resizable_tensor&&) = default;
void clear(
)
{
set_size(0,0,0,0);
}
void copy_size (
const tensor& item
)
/*!
ensures
- resizes *this so that: have_same_dimensions(#*this, item)==true
!*/
{
set_size(item.num_samples(), item.nr(), item.nc(), item.k());
}
resizable_tensor& operator= (float val)
{
tensor::operator=(val);
return *this;
}
template <typename EXP>
resizable_tensor& operator= (const matrix_exp<EXP>& item)
{
tensor::operator=(item);
return *this;
}
template <typename EXP>
resizable_tensor& operator+= (const matrix_exp<EXP>& item)
{
tensor::operator+=(item);
return *this;
}
template <typename EXP>
resizable_tensor& operator-= (const matrix_exp<EXP>& item)
{
tensor::operator-=(item);
return *this;
}
template <typename EXP>
void set_sample (
unsigned long idx,
const matrix_exp<EXP>& item
)
{
tensor::set_sample(idx, item);
}
template <typename EXP>
void add_to_sample (
unsigned long idx,
const matrix_exp<EXP>& item
)
{
tensor::add_to_sample(idx, item);
}
resizable_tensor& operator= (const resizable_tensor&) = default;
resizable_tensor& operator= (resizable_tensor&&) = default;
resizable_tensor& operator= (const tensor& x)
{
tensor::operator=(x);
return *this;
}
void set_size(
long n_, long nr_ = 1, long nc_ = 1, long k_ = 1
)
{
m_n = n_;
m_nr = nr_;
m_nc = nc_;
m_k = k_;
data.set_size(m_n*m_nr*m_nc*m_k);
}
};
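    // E.g. resizable_tensor data(10, 28, 28, 1); allocates a tensor holding 10
    // samples, each a 28x28 image with one channel.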
// ----------------------------------------------------------------------------------------
inline double dot(
const tensor& a,
const tensor& b
)
{
DLIB_CASSERT(a.size() == b.size(), "");
const float* da = a.host();
const float* db = b.host();
double sum = 0;
for (size_t i = 0; i < a.size(); ++i)
sum += da[i]*db[i];
return sum;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_DNn_TENSOR_H_