Commit ab605d15 authored by Davis King

Added ADAM solver.

parent f42b4ad2
@@ -417,6 +417,46 @@ namespace dlib
}
}
// -----------------------------------------------------------------------------------
    void compute_adam_update (
        tensor& s,
        tensor& m,
        tensor& v,
        const float t,
        const float learning_rate,
        const float weight_decay,
        const float momentum1,
        const float momentum2,
        const tensor& params,
        const tensor& params_grad
    )
    {
        DLIB_CASSERT(s.size() == m.size() &&
                     s.size() == v.size() &&
                     s.size() == params.size() &&
                     s.size() == params_grad.size(),"");
        const float eps = 1e-8;
        const float alpha = learning_rate*std::sqrt(1-std::pow(momentum2,t))/(1-std::pow(momentum1, t));

        // The loop is equivalent to doing this:
        //   m = momentum1*m + (1-momentum1) * (weight_decay*params + params_grad);
        //   v = momentum2*v + (1-momentum2)*squared(weight_decay*params + params_grad);
        //   s = -alpha*m/(sqrt(v) + eps);
        auto pm = m.host();
        auto pv = v.host();
        auto ps = s.host_write_only();
        auto pparams = params.host();
        auto ppgrad = params_grad.host();
        for (size_t i = 0; i < params.size(); ++i)
        {
            float g = weight_decay*pparams[i] + ppgrad[i];
            pm[i] = momentum1*pm[i] + (1-momentum1)*g;
            pv[i] = momentum2*pv[i] + (1-momentum2)*g*g;
            ps[i] = -alpha*pm[i]/(std::sqrt(pv[i]) + eps);
        }
    }
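To make the bias-corrected step size alpha above concrete: on the very first call (t == 1) the magnitude of the update works out to exactly the learning rate, no matter how large the gradient is. Below is a standalone recreation of the per-element arithmetic, not part of this commit; the gradient value is hypothetical and g stands for weight_decay*param + grad, as in the loop above.

#include <cmath>
#include <cstdio>

int main()
{
    const float learning_rate = 0.001f, momentum1 = 0.9f, momentum2 = 0.999f, eps = 1e-8f;
    const float t = 1;       // first update step
    float m = 0, v = 0;      // running first and second moment estimates
    const float g = 42.0f;   // hypothetical combined gradient: weight_decay*param + grad

    // Same formulas as the loop above.
    const float alpha = learning_rate*std::sqrt(1-std::pow(momentum2,t))/(1-std::pow(momentum1,t));
    m = momentum1*m + (1-momentum1)*g;
    v = momentum2*v + (1-momentum2)*g*g;
    const float s = -alpha*m/(std::sqrt(v) + eps);

    // Prints approximately -0.001, i.e. -learning_rate, since at t==1
    // alpha*m/sqrt(v) reduces to learning_rate*g/|g|.
    std::printf("first ADAM step = %g\n", s);
    return 0;
}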
// -----------------------------------------------------------------------------------
void batch_normalize_inference (
......
@@ -101,6 +101,21 @@ namespace dlib
const tensor& B
);
// -----------------------------------------------------------------------------------
    void compute_adam_update (
        tensor& s,
        tensor& m,
        tensor& v,
        const float t,
        const float learning_rate,
        const float weight_decay,
        const float momentum1,
        const float momentum2,
        const tensor& params,
        const tensor& params_grad
    );
// -----------------------------------------------------------------------------------
void batch_normalize_inference (
......
@@ -386,6 +386,59 @@ namespace dlib
}
}
// ----------------------------------------------------------------------------------------
    __global__ void _cuda_compute_adam_update(
        size_t n,
        float* s,
        float* m,
        float* v,
        const float alpha,
        const float weight_decay,
        const float momentum1,
        const float momentum2,
        const float* params,
        const float* params_grad
    )
    {
        const float eps = 1e-8;
        // The loop is equivalent to doing this:
        //   m = momentum1*m + (1-momentum1) * (weight_decay*params + params_grad);
        //   v = momentum2*v + (1-momentum2)*squared(weight_decay*params + params_grad);
        //   s = -alpha*m/(sqrt(v) + eps);
        for (auto i : grid_stride_range(0, n))
        {
            float g = (weight_decay*params[i] + params_grad[i]);
            m[i] = momentum1*m[i] + (1-momentum1)*g;
            v[i] = momentum2*v[i] + (1-momentum2)*g*g;
            s[i] = -alpha*m[i]/(std::sqrt(v[i]) + eps);
        }
    }

    void compute_adam_update (
        tensor& s,
        tensor& m,
        tensor& v,
        const float t,
        const float learning_rate,
        const float weight_decay,
        const float momentum1,
        const float momentum2,
        const tensor& params,
        const tensor& params_grad
    )
    {
        DLIB_CASSERT(s.size() == m.size() &&
                     s.size() == v.size() &&
                     s.size() == params.size() &&
                     s.size() == params_grad.size(),"");
        const float alpha = learning_rate*std::sqrt(1-std::pow(momentum2,t))/(1-std::pow(momentum1, t));

        launch_kernel(_cuda_compute_adam_update,max_jobs(s.size()),
                      s.size(), s.device(), m.device(), v.device(), alpha, weight_decay,
                      momentum1, momentum2, params.device(), params_grad.device());
    }
// -----------------------------------------------------------------------------------
__global__ void _cuda_affine_transform_conv(float* d, const float* s, size_t n, const float* A, const float* B, size_t bs, size_t ks)
......
@@ -97,6 +97,21 @@ namespace dlib
const tensor& B
);
// ----------------------------------------------------------------------------------------
    void compute_adam_update (
        tensor& s,
        tensor& m,
        tensor& v,
        const float t,
        const float learning_rate,
        const float weight_decay,
        const float momentum1,
        const float momentum2,
        const tensor& params,
        const tensor& params_grad
    );
// -----------------------------------------------------------------------------------
void assign_bias_gradient (
......
@@ -80,6 +80,101 @@ namespace dlib
float momentum;
};
// ----------------------------------------------------------------------------------------
    class adam
    {
    public:

        adam(
            float learning_rate_ = 0.001,
            float weight_decay_ = 0.0005,
            float momentum1_ = 0.9,
            float momentum2_ = 0.999
        )
        {
            weight_decay = weight_decay_;
            learning_rate = learning_rate_;
            momentum1 = momentum1_;
            momentum2 = momentum2_;
            t = 0;
        }

        float get_momentum1 (
        ) const { return momentum1; }

        float get_momentum2 (
        ) const { return momentum2; }

        float get_weight_decay (
        ) const { return weight_decay; }

        float get_learning_rate (
        ) const { return learning_rate; }

        const tensor& operator() (
            const tensor& params,
            const tensor& params_grad
        )
        {
            DLIB_CASSERT(params.size() != 0,"");
            if (v.size() == 0)
            {
                m.copy_size(params_grad);
                m = 0;
                v.copy_size(params_grad);
                v = 0;
                s.copy_size(params_grad);
            }

            ++t;
            tt::compute_adam_update(s, m, v, t, learning_rate, weight_decay, momentum1, momentum2, params, params_grad);

            return s;
        }

        friend void serialize(const adam& item, std::ostream& out)
        {
            serialize("adam", out);
            serialize(item.m, out);
            serialize(item.v, out);
            serialize(item.s, out);
            serialize(item.weight_decay, out);
            serialize(item.learning_rate, out);
            serialize(item.momentum1, out);
            serialize(item.momentum2, out);
            serialize(item.t, out);
        }

        friend void deserialize(adam& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "adam")
                throw serialization_error("Unexpected version found while deserializing dlib::adam.");
            deserialize(item.m, in);
            deserialize(item.v, in);
            deserialize(item.s, in);
            deserialize(item.weight_decay, in);
            deserialize(item.learning_rate, in);
            deserialize(item.momentum1, in);
            deserialize(item.momentum2, in);
            deserialize(item.t, in);
        }

    private:
        resizable_tensor m;
        resizable_tensor v;
        resizable_tensor s;
        float weight_decay;
        float learning_rate;
        float momentum1;
        float momentum2;
        float t;
    };
// ----------------------------------------------------------------------------------------
}
......
@@ -104,6 +104,52 @@ namespace dlib
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
    class adam
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                This object implements the EXAMPLE_SOLVER interface defined above.  In
                particular, it implements the ADAM parameter update method described in
                the paper:
                    Kingma, Diederik P., and Jimmy Ba. "Adam: A method for stochastic
                    optimization." International Conference on Learning Representations. 2015.
        !*/
    public:

        adam(
            float learning_rate = 0.001,
            float weight_decay = 0.0005,
            float momentum1 = 0.9,
            float momentum2 = 0.999
        );
        /*!
            requires
                - learning_rate > 0
                - weight_decay >= 0
                - 0 <= momentum1 < 1
                - 0 <= momentum2 < 1
            ensures
                - #get_learning_rate() == learning_rate
                - #get_weight_decay() == weight_decay
                - #get_momentum1() == momentum1
                - #get_momentum2() == momentum2
        !*/

        float get_learning_rate () const;
        float get_weight_decay () const;
        float get_momentum1 () const;
        float get_momentum2 () const;
    };

    void serialize(const adam& item, std::ostream& out);
    void deserialize(adam& item, std::istream& in);
    /*!
        provides serialization support
    !*/
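For context, a solver object like this is what the training code calls once per parameter tensor on every gradient step: it is handed the parameters and their gradient and returns the update that should be added to the parameters. A minimal sketch of that calling convention follows; it assumes <dlib/dnn.h> pulls in these declarations, and the tensor sizes and values are hypothetical, not part of this commit.

#include <dlib/dnn.h>

int main()
{
    dlib::adam solver;                      // learning_rate=0.001, weight_decay=0.0005, ...

    dlib::resizable_tensor params, params_grad;
    params.set_size(1, 10);                 // stand-in for one layer's parameters
    params_grad.copy_size(params);
    params = 1;                             // hypothetical values
    params_grad = 0.5;

    for (int step = 0; step < 100; ++step)
    {
        // In a real trainer params_grad would come from back-propagation.
        const dlib::tensor& update = solver(params, params_grad);

        // Apply the update in place; the solver has already negated it.
        auto p = params.host();
        auto u = update.host();
        for (size_t i = 0; i < params.size(); ++i)
            p[i] += u[i];
    }
    return 0;
}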
// ----------------------------------------------------------------------------------------
}
......
@@ -210,6 +210,30 @@ namespace dlib { namespace tt
#endif
}
// ----------------------------------------------------------------------------------------
    void compute_adam_update (
        tensor& s,
        tensor& m,
        tensor& v,
        const float t,
        const float learning_rate,
        const float weight_decay,
        const float momentum1,
        const float momentum2,
        const tensor& params,
        const tensor& params_grad
    )
    {
#ifdef DLIB_USE_CUDA
        cuda::compute_adam_update(s, m, v, t, learning_rate, weight_decay, momentum1,
            momentum2, params, params_grad);
#else
        cpu::compute_adam_update(s, m, v, t, learning_rate, weight_decay, momentum1,
            momentum2, params, params_grad);
#endif
    }
// ----------------------------------------------------------------------------------------
void batch_normalize_inference (
......
@@ -247,6 +247,36 @@ namespace dlib { namespace tt
#dest(n,k,r,c) == A(k)*src(n,k,r,c) + B(k).
!*/
// ----------------------------------------------------------------------------------------
    void compute_adam_update (
        tensor& s,
        tensor& m,
        tensor& v,
        const float t,
        const float learning_rate,
        const float weight_decay,
        const float momentum1,
        const float momentum2,
        const tensor& params,
        const tensor& params_grad
    );
    /*!
        requires
            - s.size() == m.size() == v.size() == params.size() == params_grad.size()
            - t > 0
            - learning_rate > 0
            - weight_decay >= 0
            - 0 <= momentum1 < 1
            - 0 <= momentum2 < 1
        ensures
            - This function implements the ADAM parameter update method described in the
              paper:
                Kingma, Diederik P., and Jimmy Ba. "Adam: A method for stochastic
                optimization." International Conference on Learning Representations. 2015.
              Specifically, it implements the method shown as Algorithm 1.
            - #s is the update vector that should be added to the parameters.
    !*/
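The caller owns all of the state here: m and v must persist between calls and be zero on the first one, s only needs the right size, and t is the 1-based count of how many updates have been computed for these parameters, which drives the bias correction. A hedged sketch of that bookkeeping follows; it is essentially what the dlib::adam solver above does internally, and the helper function and its hyperparameter values are hypothetical, not part of this commit.

#include <dlib/dnn.h>

// Hypothetical helper: computes one ADAM update for params into s.
void adam_step(
    dlib::resizable_tensor& s,
    dlib::resizable_tensor& m,
    dlib::resizable_tensor& v,
    float& t,
    const dlib::tensor& params,
    const dlib::tensor& params_grad
)
{
    if (v.size() == 0)
    {
        // First call: allocate the state and zero the moment estimates.
        m.copy_size(params_grad);  m = 0;
        v.copy_size(params_grad);  v = 0;
        s.copy_size(params_grad);
    }

    ++t;  // must be >= 1 on the first call, per the requires clause above
    dlib::tt::compute_adam_update(s, m, v, t, 0.001, 0.0005, 0.9, 0.999,
                                  params, params_grad);
    // The caller then adds s to params, as in the earlier sketch.
}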
// ----------------------------------------------------------------------------------------
void batch_normalize_inference (
......