Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
b2faad0d
Commit
b2faad0d
authored
Oct 23, 2016
by
Davis King
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added l2normalize_ layer.
parent
a390f109
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
451 additions
and
1 deletion
+451
-1
cuda_dlib.cu
dlib/dnn/cuda_dlib.cu
+129
-0
cuda_dlib.h
dlib/dnn/cuda_dlib.h
+29
-0
layers.h
dlib/dnn/layers.h
+92
-0
layers_abstract.h
dlib/dnn/layers_abstract.h
+50
-0
tensor_tools.cpp
dlib/dnn/tensor_tools.cpp
+68
-0
tensor_tools.h
dlib/dnn/tensor_tools.h
+60
-0
dnn.cpp
dlib/test/dnn.cpp
+23
-1
No files found.
dlib/dnn/cuda_dlib.cu
View file @
b2faad0d
...
...
@@ -108,6 +108,135 @@ namespace dlib
// -----------------------------------------------------------------------------------
// -----------------------------------------------------------------------------------
// -----------------------------------------------------------------------------------
// Computes invnorms[i] = 1/sqrt(eps + sum_j data[i*nc+j]^2) for each of the
// nr rows of the nr x nc row-major matrix in data.
//
// NOTE(review): __syncthreads() is a block-level barrier only.  The
// initialization loop below is strided across the whole grid, so correctness
// relies on the launch configuration ensuring each invnorms[i] is written
// before any *other* block atomically accumulates into it -- confirm against
// launch_kernel()/max_jobs() semantics.
__global__ void _cuda_inverse_norms(float* invnorms, const float* data, size_t nr, size_t nc, const float eps)
{
    // initialize invnorms before we begin.
    for (auto i : grid_stride_range(0, nr))
        invnorms[i] = eps;
    __syncthreads();
    for (auto i : grid_stride_range_y(0, nr))
    {
        auto p = data + i*nc;
        float temp = 0;
        for (auto j : grid_stride_range(0, nc))
            temp += p[j]*p[j];
        // and store the sum into invnorms[i]
        warp_reduce_atomic_add(invnorms[i], temp);
    }
    __syncthreads();
    for (auto i : grid_stride_range(0, nr))
    {
        // Use a float literal so the division stays in single precision
        // instead of silently promoting the whole expression to double.
        invnorms[i] = 1.0f/std::sqrt(invnorms[i]);
    }
}
void inverse_norms (
    resizable_tensor& invnorms,
    const tensor& data,
    const double eps
)
{
    // One output element per sample: treat each sample as a row of length
    // data.size()/data.num_samples().
    const auto rows = data.num_samples();
    const auto cols = data.size()/rows;
    invnorms.set_size(rows);
    launch_kernel(_cuda_inverse_norms, max_jobs(data.size()),
        invnorms.device(), data.device(), rows, cols, eps);
}
// ----------------------------------------------------------------------------------------
// Computes out[i] = dot(row i of lhs, row i of rhs) for an nr x nc row-major
// pair of matrices.  Each thread accumulates a strided partial product which
// is then folded into out[i] via a warp reduction plus atomic add.
//
// NOTE(review): __syncthreads() is a block-level barrier; the ordering between
// the grid-strided zero-initialization and atomic adds issued by other blocks
// depends on the launch configuration -- confirm against launch_kernel().
__global__ void _cuda_dot_prods(float* out, const float* lhs, const float* rhs, size_t nr, size_t nc)
{
    // initialize out before we begin.
    for (auto i : grid_stride_range(0, nr))
        out[i] = 0;
    __syncthreads();
    for (auto i : grid_stride_range_y(0, nr))
    {
        // Pointers to the start of row i in each operand.
        auto l = lhs + i*nc;
        auto r = rhs + i*nc;
        float temp = 0;
        for (auto j : grid_stride_range(0, nc))
            temp += l[j]*r[j];
        // and store the sum into out[i]
        warp_reduce_atomic_add(out[i], temp);
    }
}
void dot_prods (
    resizable_tensor& out,
    const tensor& lhs,
    const tensor& rhs
)
{
    // Produce one dot product per sample, treating each sample as a row.
    const auto rows = lhs.num_samples();
    out.set_size(rows);
    launch_kernel(_cuda_dot_prods, max_jobs(lhs.size()),
        out.device(), lhs.device(), rhs.device(), rows, lhs.size()/rows);
}
// ----------------------------------------------------------------------------------------
// Elementwise: out[r][c] = m[r][c] * v[r] over a flattened nr x nc row-major
// matrix.
__global__ void _cuda_scale_rows(float* out, const float* m, const float* v, size_t nr, size_t nc)
{
    for (auto idx : grid_stride_range(0, nr*nc))
    {
        const auto row = idx/nc;
        out[idx] = m[idx]*v[row];
    }
}
void scale_rows (
    tensor& out,
    const tensor& m,
    const tensor& v
)
{
    // Scale each sample (row) of m by the corresponding entry of v.
    const auto rows = m.num_samples();
    launch_kernel(_cuda_scale_rows, max_jobs(m.size()),
        out.device(), m.device(), v.device(), rows, m.size()/rows);
}
// ----------------------------------------------------------------------------------------
// Elementwise: out = scale_rows(m1 - scale_rows(m2,v1), v2), i.e.
// out[r][c] = (m1[r][c] - m2[r][c]*v1[r]) * v2[r].
__global__ void _cuda_scale_rows2(float* out, const float* m1, const float* m2, const float* v1, const float* v2, size_t nr, size_t nc)
{
    for (auto idx : grid_stride_range(0, nr*nc))
    {
        const auto row = idx/nc;
        out[idx] = (m1[idx] - m2[idx]*v1[row]) * v2[row];
    }
}
// Same as _cuda_scale_rows2 but blends with the existing output:
// out[r][c] = beta*out[r][c] + (m1[r][c] - m2[r][c]*v1[r]) * v2[r].
__global__ void _cuda_scale_rows2_beta(const float beta, float* out, const float* m1, const float* m2, const float* v1, const float* v2, size_t nr, size_t nc)
{
    for (auto idx : grid_stride_range(0, nr*nc))
    {
        const auto row = idx/nc;
        out[idx] = beta*out[idx] + (m1[idx] - m2[idx]*v1[row]) * v2[row];
    }
}
void scale_rows2 (
    float beta,
    tensor& out,
    const tensor& m1,
    const tensor& m2,
    const tensor& v1,
    const tensor& v2
)
{
    const auto rows = m1.num_samples();
    const auto cols = m1.size()/rows;
    // beta==0 dispatches to a kernel that never reads out's (possibly
    // uninitialized) contents.
    if (beta == 0)
    {
        launch_kernel(_cuda_scale_rows2, max_jobs(m1.size()),
            out.device(), m1.device(), m2.device(), v1.device(), v2.device(),
            rows, cols);
    }
    else
    {
        launch_kernel(_cuda_scale_rows2_beta, max_jobs(m1.size()), beta,
            out.device(), m1.device(), m2.device(), v1.device(), v2.device(),
            rows, cols);
    }
}
// -----------------------------------------------------------------------------------
__global__ void _cuda_multiply1(float* d, const float* s1, const float* s2, size_t n)
...
...
dlib/dnn/cuda_dlib.h
View file @
b2faad0d
...
...
@@ -108,6 +108,35 @@ namespace dlib
// -----------------------------------------------------------------------------------
void
inverse_norms
(
resizable_tensor
&
invnorms
,
const
tensor
&
data
,
const
double
eps
);
void
dot_prods
(
resizable_tensor
&
out
,
const
tensor
&
lhs
,
const
tensor
&
rhs
);
void
scale_rows
(
tensor
&
out
,
const
tensor
&
m
,
const
tensor
&
v
);
void
scale_rows2
(
float
beta
,
tensor
&
out
,
const
tensor
&
m1
,
const
tensor
&
m2
,
const
tensor
&
v1
,
const
tensor
&
v2
);
// ------------------------------------------------------------------------------------
void
multiply
(
bool
add_to
,
tensor
&
dest
,
...
...
dlib/dnn/layers.h
View file @
b2faad0d
...
...
@@ -2277,6 +2277,98 @@ namespace dlib
typename
SUBNET
>
using
inception5
=
concat5
<
itag1
,
itag2
,
itag3
,
itag4
,
itag5
,
itag1
<
B1
<
iskip
<
itag2
<
B2
<
iskip
<
itag3
<
B3
<
iskip
<
itag4
<
B4
<
iskip
<
itag5
<
B5
<
itag0
<
SUBNET
>>>>>>>>>>>>>>>>
;
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
const
double
DEFAULT_L2_NORM_EPS
=
1e-5
;
class
l2normalize_
{
public
:
explicit
l2normalize_
(
double
eps_
=
DEFAULT_L2_NORM_EPS
)
:
eps
(
eps_
)
{
}
double
get_eps
()
const
{
return
eps
;
}
template
<
typename
SUBNET
>
void
setup
(
const
SUBNET
&
/*sub*/
)
{
}
void
forward_inplace
(
const
tensor
&
input
,
tensor
&
output
)
{
tt
::
inverse_norms
(
norm
,
input
,
eps
);
tt
::
scale_rows
(
output
,
input
,
norm
);
}
void
backward_inplace
(
const
tensor
&
computed_output
,
const
tensor
&
gradient_input
,
tensor
&
data_grad
,
tensor
&
/*params_grad*/
)
{
if
(
is_same_object
(
gradient_input
,
data_grad
))
{
tt
::
dot_prods
(
temp
,
gradient_input
,
computed_output
);
tt
::
scale_rows2
(
0
,
data_grad
,
gradient_input
,
computed_output
,
temp
,
norm
);
}
else
{
tt
::
dot_prods
(
temp
,
gradient_input
,
computed_output
);
tt
::
scale_rows2
(
1
,
data_grad
,
gradient_input
,
computed_output
,
temp
,
norm
);
}
}
const
tensor
&
get_layer_params
()
const
{
return
params
;
}
tensor
&
get_layer_params
()
{
return
params
;
}
friend
void
serialize
(
const
l2normalize_
&
item
,
std
::
ostream
&
out
)
{
serialize
(
"l2normalize_"
,
out
);
serialize
(
item
.
eps
,
out
);
}
friend
void
deserialize
(
l2normalize_
&
item
,
std
::
istream
&
in
)
{
std
::
string
version
;
deserialize
(
version
,
in
);
if
(
version
!=
"l2normalize_"
)
throw
serialization_error
(
"Unexpected version '"
+
version
+
"' found while deserializing dlib::l2normalize_."
);
deserialize
(
item
.
eps
,
in
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
l2normalize_
&
item
)
{
out
<<
"l2normalize"
;
out
<<
" eps="
<<
item
.
eps
;
return
out
;
}
friend
void
to_xml
(
const
l2normalize_
&
item
,
std
::
ostream
&
out
)
{
out
<<
"<l2normalize"
;
out
<<
" eps='"
<<
item
.
eps
<<
"'"
;
out
<<
"/>
\n
"
;
}
private
:
double
eps
;
resizable_tensor
params
;
// unused
// Here only to avoid reallocation and as a cache between forward/backward
// functions.
resizable_tensor
norm
;
resizable_tensor
temp
;
};
template
<
typename
SUBNET
>
using
l2normalize
=
add_layer
<
l2normalize_
,
SUBNET
>
;
// ----------------------------------------------------------------------------------------
}
...
...
dlib/dnn/layers_abstract.h
View file @
b2faad0d
...
...
@@ -1930,6 +1930,56 @@ namespace dlib
using
inception5
=
concat5
<
itag1
,
itag2
,
itag3
,
itag4
,
itag5
,
itag1
<
B1
<
iskip
<
itag2
<
B2
<
iskip
<
itag3
<
B3
<
iskip
<
itag4
<
B4
<
iskip
<
itag5
<
B5
<
itag0
<
SUBNET
>>>>>>>>>>>>>>>>
;
// ----------------------------------------------------------------------------------------
const double DEFAULT_L2_NORM_EPS = 1e-5;

class l2normalize_
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
            defined above.  It takes tensors as input and L2 normalizes them.  In particular,
            it has the following properties:
                - The output tensors from this layer have the same dimensions as the
                  input tensors.
                - If you think of each input tensor as a set of tensor::num_samples()
                  vectors, then the output tensor contains the same vectors except they
                  have been length normalized so that their L2 norms are all 1.  I.e.
                  for each vector v we will have ||v||==1.
    !*/

public:
    explicit l2normalize_(
        // The constant is declared in this namespace (see above), not in tt,
        // so the default argument must not carry a tt:: qualifier.
        double eps = DEFAULT_L2_NORM_EPS
    );
    /*!
        requires
            - eps > 0
        ensures
            - #get_eps() == eps
    !*/

    double get_eps(
    ) const;
    /*!
        ensures
            - When we normalize a vector we divide it by its L2 norm.  However, the
              get_eps() value is added to the squared norm prior to division to avoid
              ever dividing by zero.
    !*/

    template <typename SUBNET> void setup (const SUBNET& sub);
    void forward_inplace(const tensor& input, tensor& output);
    void backward_inplace(
        const tensor& computed_output,
        const tensor& gradient_input,
        tensor& data_grad,
        tensor& params_grad
    );
    const tensor& get_layer_params() const;
    tensor& get_layer_params();
    /*!
        These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
    !*/
};
// ----------------------------------------------------------------------------------------
}
...
...
dlib/dnn/tensor_tools.cpp
View file @
b2faad0d
...
...
@@ -41,6 +41,74 @@ namespace dlib
namespace
dlib
{
namespace
tt
{
// ----------------------------------------------------------------------------------------
// Sets invnorms to the reciprocal of the eps-regularized L2 norm of each row
// of mat(data), dispatching to the CUDA implementation when available.
void inverse_norms (
    resizable_tensor& invnorms,
    const tensor& data,
    const double eps
)
{
#ifdef DLIB_USE_CUDA
    cuda::inverse_norms(invnorms, data, eps);
#else
    invnorms = reciprocal(sqrt(sum_cols(squared(mat(data))) + eps));
#endif
}
// Computes out[i] = dot product of row i of mat(lhs) with row i of mat(rhs).
void dot_prods (
    resizable_tensor& out,
    const tensor& lhs,
    const tensor& rhs
)
{
    // The interface spec in tensor_tools.h requires this but it was never
    // checked here.
    DLIB_CASSERT(have_same_dimensions(lhs,rhs));
#ifdef DLIB_USE_CUDA
    cuda::dot_prods(out, lhs, rhs);
#else
    out = sum_cols(pointwise_multiply(mat(lhs), mat(rhs)));
#endif
}
// Performs out = scale_rows(mat(m), mat(v)): each row of m is multiplied by
// the corresponding entry of v.
void scale_rows (
    tensor& out,
    const tensor& m,
    const tensor& v
)
{
    DLIB_CASSERT(have_same_dimensions(out,m));
    // These two preconditions are documented in tensor_tools.h but were not
    // previously checked.
    DLIB_CASSERT(is_vector(mat(v)));
    DLIB_CASSERT(v.size() == m.num_samples());
#ifdef DLIB_USE_CUDA
    cuda::scale_rows(out, m, v);
#else
    out = scale_rows(mat(m), mat(v));
#endif
}
// Performs out = beta*out + scale_rows(mat(m1) - scale_rows(mat(m2),mat(v1)), mat(v2)),
// dispatching to the CUDA implementation when available.
void scale_rows2 (
    float beta,
    tensor& out,
    const tensor& m1,
    const tensor& m2,
    const tensor& v1,
    const tensor& v2
)
{
    DLIB_CASSERT(have_same_dimensions(out,m1));
    DLIB_CASSERT(have_same_dimensions(out,m2));
    DLIB_CASSERT(have_same_dimensions(v1,v2));
    DLIB_CASSERT(is_vector(mat(v1)));
    DLIB_CASSERT(v1.size() == m1.num_samples());
#ifdef DLIB_USE_CUDA
    cuda::scale_rows2(beta, out, m1, m2, v1, v2);
#else
    // When beta==0 skip reading out entirely so uninitialized contents are
    // never observed.
    if (beta == 0)
        out = scale_rows(mat(m1) - scale_rows(mat(m2),mat(v1)), mat(v2));
    else
        out = beta*mat(out) + scale_rows(mat(m1) - scale_rows(mat(m2),mat(v1)), mat(v2));
#endif
}
// ----------------------------------------------------------------------------------------
void
gemm
(
...
...
dlib/dnn/tensor_tools.h
View file @
b2faad0d
...
...
@@ -22,6 +22,66 @@ namespace dlib
namespace
dlib
{
namespace
tt
{
// ----------------------------------------------------------------------------------------
void
inverse_norms
(
resizable_tensor
&
invnorms
,
const
tensor
&
data
,
const
double
eps
);
/*!
ensures
- #invnorms == reciprocal(sqrt(sum_cols(squared(mat(data))) + eps))
!*/
void
dot_prods
(
resizable_tensor
&
out
,
const
tensor
&
lhs
,
const
tensor
&
rhs
);
/*!
requires
- have_same_dimensions(lhs,rhs) == true
ensures
- #out.num_samples() == lhs.num_samples()
- #out.k() == #out.nr() == #out.nc() == 1
- #out == sum_cols(pointwise_multiply(mat(lhs), mat(rhs)));
!*/
void
scale_rows
(
tensor
&
out
,
const
tensor
&
m
,
const
tensor
&
v
);
/*!
requires
- have_same_dimensions(out,m) == true
- is_vector(mat(v)) == true
- v.size() == m.num_samples()
ensures
- performs: out = scale_rows(mat(m),mat(v));
!*/
void
scale_rows2
(
float
beta
,
tensor
&
out
,
const
tensor
&
m1
,
const
tensor
&
m2
,
const
tensor
&
v1
,
const
tensor
&
v2
);
/*!
requires
- have_same_dimensions(out,m1) == true
- have_same_dimensions(out,m2) == true
- have_same_dimensions(v1,v2) == true
- is_vector(mat(v1)) == true
- v1.size() == m1.num_samples()
ensures
- performs:
out = beta*out + scale_rows(mat(m1) - scale_rows(mat(m2),mat(v1)), mat(v2));
!*/
// ----------------------------------------------------------------------------------------
void
gemm
(
...
...
dlib/test/dnn.cpp
View file @
b2faad0d
...
...
@@ -927,7 +927,6 @@ namespace
A
=
dest
;
B
=
dest
;
tensor_rand
rnd
;
rnd
.
fill_uniform
(
dest
);
rnd
.
fill_uniform
(
A
);
rnd
.
fill_uniform
(
B
);
...
...
@@ -960,6 +959,23 @@ namespace
cuda
::
multiply
(
false
,
dest
,
A
,
B
);
DLIB_TEST
(
max
(
abs
(
mat
(
dest
)
-
pointwise_multiply
(
AA
,
mat
(
B
))))
<
1e-6
);
}
{
resizable_tensor
invnorms1
,
invnorms2
;
resizable_tensor
data
(
4
,
5
),
out1
,
out2
;
rnd
.
fill_uniform
(
data
);
const
double
eps
=
0.1
;
invnorms2
=
reciprocal
(
sqrt
(
sum_cols
(
squared
(
mat
(
data
)))
+
eps
));
tt
::
inverse_norms
(
invnorms1
,
data
,
eps
);
DLIB_TEST
(
max
(
abs
(
mat
(
invnorms1
)
-
mat
(
invnorms2
)))
<
1e-6
);
out1
.
copy_size
(
data
);
tt
::
scale_rows
(
out1
,
data
,
invnorms1
);
out2
=
scale_rows
(
mat
(
data
),
mat
(
invnorms1
));
DLIB_TEST
(
max
(
abs
(
mat
(
out1
)
-
mat
(
out2
)))
<
1e-6
);
}
}
// ----------------------------------------------------------------------------------------
...
...
@@ -1302,6 +1318,12 @@ namespace
void
test_layers
()
{
{
print_spinner
();
l2normalize_
l
;
auto
res
=
test_layer
(
l
);
DLIB_TEST_MSG
(
res
,
res
);
}
{
print_spinner
();
multiply_
l
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment