Commit 8a0e2e4d authored by Davis King's avatar Davis King

Some more cleanup. Also filled out the solver spec.

parent 5c8a2a4c
......@@ -1401,6 +1401,6 @@ namespace dlib
}
#endif // #define DLIB_DNn_CORE_H_
#endif // DLIB_DNn_CORE_H_
......@@ -133,5 +133,5 @@ namespace dlib
}
#endif // #define DLIB_DNn_INPUT_H_
#endif // DLIB_DNn_INPUT_H_
......@@ -237,6 +237,6 @@ namespace dlib
}
#endif // #define DLIB_DNn_LAYERS_H_
#endif // DLIB_DNn_LAYERS_H_
......@@ -48,10 +48,13 @@ namespace dlib
f(data,parameters), that takes in a data tensor, some parameters, and
produces an output tensor. You create an entire deep network by composing
these functions. Importantly, you are able to use a wide range of
different functions to accommodate whatever task you are trying to accomplish.
different functions to accommodate the task you are trying to accomplish.
Dlib includes a number of common layer types but if you want to define your
own then you simply implement a class with the same interface as EXAMPLE_LAYER_.
own then you simply implement a class with the same interface as
EXAMPLE_LAYER_.
Note that there is no dlib::EXAMPLE_LAYER_ type. It is shown here purely
to document the interface that a layer object must implement.
!*/
public:
......@@ -234,5 +237,5 @@ namespace dlib
}
#endif // #define DLIB_DNn_LAYERS_H_
#endif // DLIB_DNn_LAYERS_ABSTRACT_H_
......@@ -170,6 +170,6 @@ namespace dlib
}
#endif // #define DLIB_DNn_LOSS_H_
#endif // DLIB_DNn_LOSS_H_
......@@ -3,53 +3,60 @@
#ifndef DLIB_DNn_SOLVERS_H_
#define DLIB_DNn_SOLVERS_H_
#include "solvers_abstract.h"
#include "tensor.h"
#include <iostream>
namespace dlib
{
/*
class EXAMPLE_SOLVER
{
};
*/
class sgd
{
    /*!
        This is a basic stochastic gradient descent solver with momentum and
        weight decay.  Each call to operator() performs:
            v = momentum*v - weight_decay*learning_rate*params - learning_rate*params_grad
            params += v
        where v is a momentum term remembered between calls (treated as zero
        on the first call).
    !*/
public:
    sgd (
        float learning_rate_ = 0.001,
        float weight_decay_ = 0.0005,
        float momentum_ = 0.9
    ) :
        // Initializer list order matches the member declaration order below.
        weight_decay(weight_decay_),
        learning_rate(learning_rate_),
        momentum(momentum_)
    {
    }

    float get_momentum (
    ) const { return momentum; }

    float get_weight_decay (
    ) const { return weight_decay; }

    float get_learning_rate (
    ) const { return learning_rate; }

    template <typename LAYER_DETAILS>
    void operator() (
        LAYER_DETAILS& l,
        const tensor& params_grad
    )
    /*!
        requires
            - l.get_layer_params().size() != 0
            - l.get_layer_params() and params_grad have the same dimensions.
        ensures
            - updates l.get_layer_params() according to the SGD with momentum
              and weight decay rule described above.
    !*/
    {
        DLIB_CASSERT(l.get_layer_params().size() != 0,"");
        // v.size() == 0 means this is the first invocation, so there is no
        // accumulated momentum yet and the momentum*v term is omitted.
        if (v.size() != 0)
            v = momentum*v - weight_decay*learning_rate*mat(l.get_layer_params()) - learning_rate*mat(params_grad);
        else
            v = - weight_decay*learning_rate*mat(l.get_layer_params()) - learning_rate*mat(params_grad);

        l.get_layer_params() += v;
    }

private:
    matrix<float> v;       // accumulated momentum term, one entry per parameter
    float weight_decay;
    float learning_rate;
    float momentum;
};
}
#endif // #define DLIB_DNn_SOLVERS_H_
#endif // DLIB_DNn_SOLVERS_H_
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_DNn_SOLVERS_ABSTRACT_H_
#ifdef DLIB_DNn_SOLVERS_ABSTRACT_H_
#include "tensor_abstract.h"
#include <iostream>
namespace dlib
{
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
class EXAMPLE_SOLVER
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            A solver defines the parameter update rule for a single layer in a
            deep neural network.  It takes a parameter gradient vector and a
            layer and updates the layer's parameters.  Importantly, each solver
            instance is used with only one layer in a network.  This allows us
            to define solvers that have per layer state, for example, a solver
            may keep a momentum term and apply it to its update rule.

            Note that there is no dlib::EXAMPLE_SOLVER type.  It is shown here
            purely to document the interface that a solver object must
            implement.
    !*/

public:

    EXAMPLE_SOLVER(
    );

    template <typename LAYER_DETAILS>
    void operator() (
        LAYER_DETAILS& l,
        const tensor& params_grad
    );
    /*!
        requires
            - LAYER_DETAILS implements the EXAMPLE_LAYER_ interface defined in
              layers_abstract.h.
            - l.get_layer_params().size() != 0
            - l.get_layer_params() and params_grad have the same dimensions.
            - When this function is invoked on a particular solver instance,
              it is always supplied with the same LAYER_DETAILS object.
        ensures
            - Updates the parameters in l.  That is, l.get_layer_params() is
              modified based on the parameter gradient vector stored in
              params_grad.
    !*/
};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
class sgd
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            This object implements the EXAMPLE_SOLVER interface defined above.
            It is a basic stochastic gradient descent solver which uses
            momentum and weight decay.  In particular, it performs the
            following update each time the solver is invoked:
                v = momentum*v - weight_decay*learning_rate*l.get_layer_params() - learning_rate*params_grad;
                l.get_layer_params() += v;
            Here v is a momentum term that is remembered by the solver from
            one invocation of operator() to the next (it acts as zero on the
            first invocation).
    !*/

public:

    sgd(
        float learning_rate = 0.001,
        float weight_decay = 0.0005,
        float momentum = 0.9
    );
    /*!
        requires
            - learning_rate > 0
            - weight_decay >= 0
            - momentum >= 0
        ensures
            - #get_learning_rate() == learning_rate
            - #get_weight_decay() == weight_decay
            - #get_momentum() == momentum
    !*/

    float get_learning_rate () const;
    float get_weight_decay () const;
    float get_momentum () const;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_DNn_SOLVERS_ABSTRACT_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment