Commit ab605d15, authored Mar 27, 2016 by Davis King
Added ADAM solver.
parent f42b4ad2
Showing 8 changed files with 318 additions and 0 deletions (+318 −0)
cpu_dlib.cpp        (dlib/dnn/cpu_dlib.cpp)        +40 −0
cpu_dlib.h          (dlib/dnn/cpu_dlib.h)          +15 −0
cuda_dlib.cu        (dlib/dnn/cuda_dlib.cu)        +53 −0
cuda_dlib.h         (dlib/dnn/cuda_dlib.h)         +15 −0
solvers.h           (dlib/dnn/solvers.h)           +95 −0
solvers_abstract.h  (dlib/dnn/solvers_abstract.h)  +46 −0
tensor_tools.cpp    (dlib/dnn/tensor_tools.cpp)    +24 −0
tensor_tools.h      (dlib/dnn/tensor_tools.h)      +30 −0
dlib/dnn/cpu_dlib.cpp

@@ -417,6 +417,46 @@ namespace dlib

            }
        }
    // -----------------------------------------------------------------------------------

        void compute_adam_update (
            tensor& s,
            tensor& m,
            tensor& v,
            const float t,
            const float learning_rate,
            const float weight_decay,
            const float momentum1,
            const float momentum2,
            const tensor& params,
            const tensor& params_grad
        )
        {
            DLIB_CASSERT(s.size() == m.size() &&
                         s.size() == v.size() &&
                         s.size() == params.size() &&
                         s.size() == params_grad.size(),"");
            const float eps = 1e-8;
            const float alpha = learning_rate*std::sqrt(1-std::pow(momentum2,t))/(1-std::pow(momentum1,t));

            // The loop is equivalent to doing this:
            //   m = momentum1*m + (1-momentum1) * (weight_decay*params + params_grad);
            //   v = momentum2*v + (1-momentum2)*squared(weight_decay*params + params_grad);
            //   s = -alpha*m/(sqrt(v) + eps);
            auto pm = m.host();
            auto pv = v.host();
            auto ps = s.host_write_only();
            auto pparams = params.host();
            auto ppgrad = params_grad.host();
            for (size_t i = 0; i < params.size(); ++i)
            {
                float g = weight_decay*pparams[i] + ppgrad[i];
                pm[i] = momentum1*pm[i] + (1-momentum1)*g;
                pv[i] = momentum2*pv[i] + (1-momentum2)*g*g;
                ps[i] = -alpha*pm[i]/(std::sqrt(pv[i]) + eps);
            }
        }
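For readers who want to check the arithmetic outside dlib's tensor types, the following is a minimal standalone sketch of the same update rule (Algorithm 1 of the Adam paper) on plain std::vector<float> buffers. The function name adam_step and its vector interface are illustrative assumptions, not part of this commit:

#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative sketch: one Adam step on raw float vectors, mirroring the loop
// above. m and v hold the running first and second moment estimates and must
// start at zero; t is the step count, starting from 1 on the first call.
void adam_step(
    std::vector<float>& s,                  // output: update to add to the parameters
    std::vector<float>& m,                  // first moment (smoothed gradient)
    std::vector<float>& v,                  // second moment (smoothed squared gradient)
    float t, float learning_rate, float weight_decay,
    float momentum1, float momentum2,
    const std::vector<float>& params,
    const std::vector<float>& params_grad)
{
    const float eps = 1e-8f;
    // m and v start at 0, so early estimates are biased toward 0; this factor
    // rescales the step to compensate, exactly as in the dlib code above.
    const float alpha = learning_rate*std::sqrt(1 - std::pow(momentum2, t))
                                     /(1 - std::pow(momentum1, t));
    for (std::size_t i = 0; i < params.size(); ++i)
    {
        const float g = weight_decay*params[i] + params_grad[i];
        m[i] = momentum1*m[i] + (1 - momentum1)*g;
        v[i] = momentum2*v[i] + (1 - momentum2)*g*g;
        s[i] = -alpha*m[i]/(std::sqrt(v[i]) + eps);
    }
}

Note that weight decay is folded into g before the moment estimates are formed, so the L2 penalty is smoothed by the same momentum terms as the data gradient.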
    // -----------------------------------------------------------------------------------

        void batch_normalize_inference (
        ...
dlib/dnn/cpu_dlib.h

@@ -101,6 +101,21 @@ namespace dlib

            const tensor& B
        );
    // -----------------------------------------------------------------------------------

        void compute_adam_update (
            tensor& s,
            tensor& m,
            tensor& v,
            const float t,
            const float learning_rate,
            const float weight_decay,
            const float momentum1,
            const float momentum2,
            const tensor& params,
            const tensor& params_grad
        );
    // -----------------------------------------------------------------------------------

        void batch_normalize_inference (
        ...
dlib/dnn/cuda_dlib.cu

@@ -386,6 +386,59 @@ namespace dlib

            }
        }
    // ----------------------------------------------------------------------------------------

        __global__ void _cuda_compute_adam_update(
            size_t n,
            float* s,
            float* m,
            float* v,
            const float alpha,
            const float weight_decay,
            const float momentum1,
            const float momentum2,
            const float* params,
            const float* params_grad
        )
        {
            const float eps = 1e-8;
            // The loop is equivalent to doing this:
            //   m = momentum1*m + (1-momentum1) * (weight_decay*params + params_grad);
            //   v = momentum2*v + (1-momentum2)*squared(weight_decay*params + params_grad);
            //   s = -alpha*m/(sqrt(v) + eps);
            for (auto i : grid_stride_range(0, n))
            {
                float g = (weight_decay*params[i] + params_grad[i]);
                m[i] = momentum1*m[i] + (1-momentum1)*g;
                v[i] = momentum2*v[i] + (1-momentum2)*g*g;
                s[i] = -alpha*m[i]/(std::sqrt(v[i]) + eps);
            }
        }

        void compute_adam_update (
            tensor& s,
            tensor& m,
            tensor& v,
            const float t,
            const float learning_rate,
            const float weight_decay,
            const float momentum1,
            const float momentum2,
            const tensor& params,
            const tensor& params_grad
        )
        {
            DLIB_CASSERT(s.size() == m.size() &&
                         s.size() == v.size() &&
                         s.size() == params.size() &&
                         s.size() == params_grad.size(),"");
            const float alpha = learning_rate*std::sqrt(1-std::pow(momentum2,t))/(1-std::pow(momentum1,t));

            launch_kernel(_cuda_compute_adam_update,max_jobs(s.size()),
                s.size(), s.device(), m.device(), v.device(), alpha, weight_decay,
                momentum1, momentum2, params.device(), params_grad.device());
        }
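grid_stride_range is a dlib CUDA helper; the loop it produces is the standard grid-stride pattern, which lets a fixed launch configuration cover all n elements. A roughly equivalent hand-written kernel, as a sketch under that assumption (not dlib's actual helper):

        // Sketch: each thread starts at its global index and strides by the
        // total thread count of the grid until all n elements are processed.
        __global__ void adam_update_sketch(
            size_t n, float* s, float* m, float* v,
            const float alpha, const float weight_decay,
            const float momentum1, const float momentum2,
            const float* params, const float* params_grad)
        {
            const float eps = 1e-8;
            for (size_t i = blockIdx.x*blockDim.x + threadIdx.x; i < n;
                 i += (size_t)gridDim.x*blockDim.x)
            {
                const float g = weight_decay*params[i] + params_grad[i];
                m[i] = momentum1*m[i] + (1-momentum1)*g;
                v[i] = momentum2*v[i] + (1-momentum2)*g*g;
                s[i] = -alpha*m[i]/(sqrtf(v[i]) + eps);
            }
        }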
    // ----------------------------------------------------------------------------------------

        __global__ void _cuda_affine_transform_conv(float* d, const float* s, size_t n, const float* A, const float* B, size_t bs, size_t ks)
        ...
dlib/dnn/cuda_dlib.h

@@ -97,6 +97,21 @@ namespace dlib

            const tensor& B
        );
    // ----------------------------------------------------------------------------------------

        void compute_adam_update (
            tensor& s,
            tensor& m,
            tensor& v,
            const float t,
            const float learning_rate,
            const float weight_decay,
            const float momentum1,
            const float momentum2,
            const tensor& params,
            const tensor& params_grad
        );
    // -----------------------------------------------------------------------------------

        void assign_bias_gradient (
        ...
dlib/dnn/solvers.h

@@ -80,6 +80,101 @@ namespace dlib

        float momentum;
    };
    // ----------------------------------------------------------------------------------------

    class adam
    {
    public:

        adam(
            float learning_rate_ = 0.001,
            float weight_decay_ = 0.0005,
            float momentum1_ = 0.9,
            float momentum2_ = 0.999
        )
        {
            weight_decay = weight_decay_;
            learning_rate = learning_rate_;
            momentum1 = momentum1_;
            momentum2 = momentum2_;
            t = 0;
        }

        float get_momentum1 ( ) const { return momentum1; }
        float get_momentum2 ( ) const { return momentum2; }
        float get_weight_decay ( ) const { return weight_decay; }
        float get_learning_rate ( ) const { return learning_rate; }

        const tensor& operator() (
            const tensor& params,
            const tensor& params_grad
        )
        {
            DLIB_CASSERT(params.size() != 0,"");
            if (v.size() == 0)
            {
                m.copy_size(params_grad);
                m = 0;
                v.copy_size(params_grad);
                v = 0;
                s.copy_size(params_grad);
            }

            ++t;
            tt::compute_adam_update(s, m, v, t, learning_rate, weight_decay,
                momentum1, momentum2, params, params_grad);

            return s;
        }

        friend void serialize(const adam& item, std::ostream& out)
        {
            serialize("adam", out);
            serialize(item.m, out);
            serialize(item.v, out);
            serialize(item.s, out);
            serialize(item.weight_decay, out);
            serialize(item.learning_rate, out);
            serialize(item.momentum1, out);
            serialize(item.momentum2, out);
            serialize(item.t, out);
        }

        friend void deserialize(adam& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "adam")
                throw serialization_error("Unexpected version found while deserializing dlib::adam.");
            deserialize(item.m, in);
            deserialize(item.v, in);
            deserialize(item.s, in);
            deserialize(item.weight_decay, in);
            deserialize(item.learning_rate, in);
            deserialize(item.momentum1, in);
            deserialize(item.momentum2, in);
            deserialize(item.t, in);
        }

    private:
        resizable_tensor m;
        resizable_tensor v;
        resizable_tensor s;
        float weight_decay;
        float learning_rate;
        float momentum1;
        float momentum2;
        float t;
    };

    // ----------------------------------------------------------------------------------------

}
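As a usage sketch (hypothetical; in practice dlib's dnn_trainer constructs and drives the solvers, and the tensor setup below is illustrative):

    adam solver(0.001, 0.0005, 0.9, 0.999);  // learning rate, weight decay, momentum1, momentum2

    resizable_tensor params, params_grad;
    // ... size params and compute params_grad for the current mini-batch ...

    // Each call increments the internal step counter t and returns the update
    // vector s, which the caller adds to the parameters elementwise. The first
    // call also lazily allocates m, v, and s to match params_grad.
    const tensor& update = solver(params, params_grad);

Serializing t along with m and v matters: the bias-correction factor depends on the step count, so a solver restored from disk resumes with the correct step sizes.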
dlib/dnn/solvers_abstract.h

@@ -104,6 +104,52 @@ namespace dlib

        provides serialization support
    !*/
    // ----------------------------------------------------------------------------------------

    class adam
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                This object implements the EXAMPLE_SOLVER interface defined above.  In
                particular, it implements the ADAM parameter update method described in the
                paper:
                    Kingma, Diederik P., and Jimmy Ba. "Adam: A method for stochastic
                    optimization." International Conference on Learning Representations. 2015.
        !*/

    public:

        adam(
            float learning_rate = 0.001,
            float weight_decay = 0.0005,
            float momentum1 = 0.9,
            float momentum2 = 0.999
        );
        /*!
            requires
                - learning_rate > 0
                - weight_decay >= 0
                - 0 <= momentum1 < 1
                - 0 <= momentum2 < 1
            ensures
                - #get_learning_rate() == learning_rate
                - #get_weight_decay() == weight_decay
                - #get_momentum1() == momentum1
                - #get_momentum2() == momentum2
        !*/

        float get_learning_rate () const;
        float get_weight_decay () const;
        float get_momentum1 () const;
        float get_momentum2 () const;
    };

    void serialize(const adam& item, std::ostream& out);
    void deserialize(adam& item, std::istream& in);
    /*!
        provides serialization support
    !*/

    // ----------------------------------------------------------------------------------------

}
dlib/dnn/tensor_tools.cpp

@@ -210,6 +210,30 @@ namespace dlib { namespace tt

#endif
    }
    // ----------------------------------------------------------------------------------------

    void compute_adam_update (
        tensor& s,
        tensor& m,
        tensor& v,
        const float t,
        const float learning_rate,
        const float weight_decay,
        const float momentum1,
        const float momentum2,
        const tensor& params,
        const tensor& params_grad
    )
    {
#ifdef DLIB_USE_CUDA
        cuda::compute_adam_update(s, m, v, t, learning_rate, weight_decay,
            momentum1, momentum2, params, params_grad);
#else
        cpu::compute_adam_update(s, m, v, t, learning_rate, weight_decay,
            momentum1, momentum2, params, params_grad);
#endif
    }

    // ----------------------------------------------------------------------------------------
    void batch_normalize_inference (
    ...
dlib/dnn/tensor_tools.h

@@ -247,6 +247,36 @@ namespace dlib { namespace tt

            #dest(n,k,r,c) == A(k)*src(n,k,r,c) + B(k).
        !*/
    // ----------------------------------------------------------------------------------------

    void compute_adam_update (
        tensor& s,
        tensor& m,
        tensor& v,
        const float t,
        const float learning_rate,
        const float weight_decay,
        const float momentum1,
        const float momentum2,
        const tensor& params,
        const tensor& params_grad
    );
    /*!
        requires
            - s.size() == m.size() == v.size() == params.size() == params_grad.size()
            - t > 0
            - learning_rate > 0
            - weight_decay >= 0
            - 0 <= momentum1 < 1
            - 0 <= momentum2 < 1
        ensures
            - This function implements the ADAM parameter update method described in the paper:
                Kingma, Diederik P., and Jimmy Ba. "Adam: A method for stochastic
                optimization." International Conference on Learning Representations. 2015.
              Specifically, it implements the method shown as Algorithm 1.
            - #s is the update vector that should be added to the parameters.
    !*/
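One way to sanity-check the bias correction in Algorithm 1: m and v start at zero, so after the first step m == (1-momentum1)*g and v == (1-momentum2)*g*g, and the alpha factor cancels both, giving a first update of roughly -learning_rate*sign(g) regardless of gradient scale. A small verification sketch (illustrative, not dlib code):

#include <cmath>
#include <cstdio>

// Checks that at t = 1 the bias-corrected Adam step is about
// -learning_rate * sign(g), independent of the magnitude of g.
int main()
{
    const float lr = 0.001f, b1 = 0.9f, b2 = 0.999f, eps = 1e-8f, t = 1;
    for (float g : {0.01f, 1.0f, 100.0f})
    {
        const float m = (1 - b1)*g;      // first moment after one step
        const float v = (1 - b2)*g*g;    // second moment after one step
        const float alpha = lr*std::sqrt(1 - std::pow(b2, t))/(1 - std::pow(b1, t));
        const float s = -alpha*m/(std::sqrt(v) + eps);
        std::printf("g = %8.2f  ->  step = %g\n", g, s);  // ~ -0.001 each time
    }
    return 0;
}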
    // ----------------------------------------------------------------------------------------

        void batch_normalize_inference (
        ...