Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
cbce85ec
Commit
cbce85ec
authored
Dec 05, 2015
by
Davis King
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added GPU versions of the batch normalization functions.
parent
06534305
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
107 additions
and
4 deletions
+107
-4
cpu_dlib.cpp
dlib/dnn/cpu_dlib.cpp
+4
-4
cuda_dlib.cu
dlib/dnn/cuda_dlib.cu
+0
-0
dnn.cpp
dlib/test/dnn.cpp
+103
-0
No files found.
dlib/dnn/cpu_dlib.cpp
View file @
cbce85ec
...
@@ -185,7 +185,7 @@ namespace dlib
...
@@ -185,7 +185,7 @@ namespace dlib
for
(
long
i
=
0
;
i
<
num
;
++
i
)
for
(
long
i
=
0
;
i
<
num
;
++
i
)
{
{
auto
actual_var
=
p_invstds
[
i
]
-
p_means
[
i
]
*
p_means
[
i
];
auto
actual_var
=
p_invstds
[
i
]
-
p_means
[
i
]
*
p_means
[
i
];
p_invstds
[
i
]
=
1.0
/
std
::
sqrt
(
actual_var
+
eps
);
p_invstds
[
i
]
=
1.0
f
/
std
::
sqrt
(
actual_var
+
eps
);
}
}
p_src
=
src
.
host
();
p_src
=
src
.
host
();
...
@@ -361,8 +361,8 @@ namespace dlib
...
@@ -361,8 +361,8 @@ namespace dlib
// compute variances
// compute variances
for
(
long
k
=
0
;
k
<
src
.
k
();
++
k
)
for
(
long
k
=
0
;
k
<
src
.
k
();
++
k
)
{
{
auto
actual_var
=
p_invstds
[
k
]
-
p_means
[
k
]
*
p_means
[
k
];
float
actual_var
=
p_invstds
[
k
]
-
p_means
[
k
]
*
p_means
[
k
];
p_invstds
[
k
]
=
1.0
/
std
::
sqrt
(
actual_var
+
eps
);
p_invstds
[
k
]
=
1.0
f
/
std
::
sqrt
(
actual_var
+
eps
);
}
}
p_src
=
src
.
host
();
p_src
=
src
.
host
();
...
@@ -421,7 +421,7 @@ namespace dlib
...
@@ -421,7 +421,7 @@ namespace dlib
{
{
for
(
long
k
=
0
;
k
<
src
.
k
();
++
k
)
for
(
long
k
=
0
;
k
<
src
.
k
();
++
k
)
{
{
const
auto
invstd_pow
=
-
0.5
*
std
::
pow
(
p_invstds
[
k
],
3.0
f
);
const
float
invstd_pow
=
-
0.5
*
std
::
pow
(
p_invstds
[
k
],
3.0
f
);
for
(
long
i
=
0
;
i
<
num
;
++
i
)
for
(
long
i
=
0
;
i
<
num
;
++
i
)
{
{
const
float
x_hat
=
(
*
p_src
-
p_means
[
k
])
*
p_invstds
[
k
];
const
float
x_hat
=
(
*
p_src
-
p_means
[
k
])
*
p_invstds
[
k
];
...
...
dlib/dnn/cuda_dlib.cu
View file @
cbce85ec
This diff is collapsed.
Click to expand it.
dlib/test/dnn.cpp
View file @
cbce85ec
...
@@ -460,6 +460,107 @@ namespace
...
@@ -460,6 +460,107 @@ namespace
}
}
#endif
#endif
// ----------------------------------------------------------------------------------------
void
compare_bn_gpu_and_cpu
()
{
print_spinner
();
resizable_tensor
dest
,
dest2
;
resizable_tensor
means
,
means2
;
resizable_tensor
invstds
,
invstds2
;
resizable_tensor
src
(
64
,
20
,
100
,
100
);
resizable_tensor
gamma
(
1
,
20
,
100
,
100
);
resizable_tensor
beta
(
1
,
20
,
100
,
100
);
gamma
=
2
;
beta
=
3
;
tt
::
tensor_rand
rnd
;
rnd
.
fill_uniform
(
src
);
cpu
::
batch_normalize
(
dest
,
means
,
invstds
,
src
,
gamma
,
beta
);
cuda
::
batch_normalize
(
dest2
,
means2
,
invstds2
,
src
,
gamma
,
beta
);
dlog
<<
LINFO
<<
"dest error: "
<<
max
(
abs
(
mat
(
dest
)
-
mat
(
dest2
)));
dlog
<<
LINFO
<<
"means error: "
<<
max
(
abs
(
mat
(
means
)
-
mat
(
means2
)));
dlog
<<
LINFO
<<
"invstds error: "
<<
max
(
abs
(
mat
(
invstds
)
-
mat
(
invstds2
)));
DLIB_TEST
(
max
(
abs
(
mat
(
dest
)
-
mat
(
dest2
)))
<
1e-5
);
DLIB_TEST
(
max
(
abs
(
mat
(
means
)
-
mat
(
means2
)))
<
1e-5
);
DLIB_TEST
(
max
(
abs
(
mat
(
invstds
)
-
mat
(
invstds2
)))
<
1e-5
);
// now check that the gradients match as well
resizable_tensor
gradient_input
;
resizable_tensor
src_grad
,
gamma_grad
,
beta_grad
;
resizable_tensor
src_grad2
,
gamma_grad2
,
beta_grad2
;
gradient_input
.
copy_size
(
dest
);
src_grad
.
copy_size
(
src
);
src_grad
=
0
;
src_grad2
=
src_grad
;
gamma_grad
.
copy_size
(
gamma
);
gamma_grad
=
0
;
gamma_grad2
=
gamma_grad
;
beta_grad
.
copy_size
(
beta
);
beta_grad
=
0
;
beta_grad2
=
beta_grad
;
rnd
.
fill_uniform
(
gradient_input
);
cpu
::
batch_normalize_gradient
cpu_bng
;
cpu_bng
(
gradient_input
,
means
,
invstds
,
src
,
gamma
,
src_grad
,
gamma_grad
,
beta_grad
);
cuda
::
batch_normalize_gradient
cuda_bng
;
cuda_bng
(
gradient_input
,
means
,
invstds
,
src
,
gamma
,
src_grad2
,
gamma_grad2
,
beta_grad2
);
dlog
<<
LINFO
<<
"src_grad error: "
<<
max
(
abs
(
mat
(
src_grad
)
-
mat
(
src_grad2
)));
dlog
<<
LINFO
<<
"gamma_grad error: "
<<
max
(
abs
(
mat
(
gamma_grad
)
-
mat
(
gamma_grad2
)));
dlog
<<
LINFO
<<
"beta_grad error: "
<<
max
(
abs
(
mat
(
beta_grad
)
-
mat
(
beta_grad2
)));
DLIB_TEST
(
max
(
abs
(
mat
(
src_grad
)
-
mat
(
src_grad2
)))
<
1e-5
);
DLIB_TEST
(
max
(
abs
(
mat
(
gamma_grad
)
-
mat
(
gamma_grad2
)))
<
1e-5
);
DLIB_TEST
(
max
(
abs
(
mat
(
beta_grad
)
-
mat
(
beta_grad2
)))
<
1e-5
);
}
void
compare_bn_conv_gpu_and_cpu
()
{
print_spinner
();
resizable_tensor
dest
,
dest2
;
resizable_tensor
means
,
means2
;
resizable_tensor
invstds
,
invstds2
;
resizable_tensor
src
(
2
,
8
,
10
,
9
);
resizable_tensor
gamma
(
1
,
8
);
resizable_tensor
beta
(
1
,
8
);
gamma
=
2
;
beta
=
3
;
tt
::
tensor_rand
rnd
;
rnd
.
fill_uniform
(
src
);
cpu
::
batch_normalize_conv
(
dest
,
means
,
invstds
,
src
,
gamma
,
beta
);
cuda
::
batch_normalize_conv
(
dest2
,
means2
,
invstds2
,
src
,
gamma
,
beta
);
dlog
<<
LINFO
<<
"dest error: "
<<
max
(
abs
(
mat
(
dest
)
-
mat
(
dest2
)));
dlog
<<
LINFO
<<
"means error: "
<<
max
(
abs
(
mat
(
means
)
-
mat
(
means2
)));
dlog
<<
LINFO
<<
"invstds error: "
<<
max
(
abs
(
mat
(
invstds
)
-
mat
(
invstds2
)));
DLIB_TEST
(
max
(
abs
(
mat
(
dest
)
-
mat
(
dest2
)))
<
1e-4
);
DLIB_TEST
(
max
(
abs
(
mat
(
means
)
-
mat
(
means2
)))
<
1e-4
);
DLIB_TEST
(
max
(
abs
(
mat
(
invstds
)
-
mat
(
invstds2
)))
<
1e-4
);
resizable_tensor
gradient_input
;
resizable_tensor
src_grad
,
gamma_grad
,
beta_grad
;
resizable_tensor
src_grad2
,
gamma_grad2
,
beta_grad2
;
gradient_input
.
copy_size
(
dest
);
src_grad
.
copy_size
(
src
);
src_grad
=
0
;
src_grad2
=
src_grad
;
gamma_grad
.
copy_size
(
gamma
);
gamma_grad
=
0
;
gamma_grad2
=
gamma_grad
;
beta_grad
.
copy_size
(
beta
);
beta_grad
=
0
;
beta_grad2
=
beta_grad
;
rnd
.
fill_uniform
(
gradient_input
);
cpu
::
batch_normalize_conv_gradient
cpu_bng
;
cpu_bng
(
gradient_input
,
means
,
invstds
,
src
,
gamma
,
src_grad
,
gamma_grad
,
beta_grad
);
cuda
::
batch_normalize_conv_gradient
cuda_bng
;
cuda_bng
(
gradient_input
,
means
,
invstds
,
src
,
gamma
,
src_grad2
,
gamma_grad2
,
beta_grad2
);
dlog
<<
LINFO
<<
"src_grad error: "
<<
max
(
abs
(
mat
(
src_grad
)
-
mat
(
src_grad2
)));
dlog
<<
LINFO
<<
"gamma_grad error: "
<<
max
(
abs
(
mat
(
gamma_grad
)
-
mat
(
gamma_grad2
)));
dlog
<<
LINFO
<<
"beta_grad error: "
<<
max
(
abs
(
mat
(
beta_grad
)
-
mat
(
beta_grad2
)));
DLIB_TEST
(
max
(
abs
(
mat
(
src_grad
)
-
mat
(
src_grad2
)))
<
1e-4
);
DLIB_TEST
(
max
(
abs
(
mat
(
gamma_grad
)
-
mat
(
gamma_grad2
)))
<
1e-4
);
DLIB_TEST
(
max
(
abs
(
mat
(
beta_grad
)
-
mat
(
beta_grad2
)))
<
1e-4
);
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
class
dnn_tester
:
public
tester
class
dnn_tester
:
public
tester
...
@@ -488,6 +589,8 @@ namespace
...
@@ -488,6 +589,8 @@ namespace
test_batch_normalize
();
test_batch_normalize
();
test_batch_normalize_conv
();
test_batch_normalize_conv
();
test_basic_tensor_ops
();
test_basic_tensor_ops
();
compare_bn_gpu_and_cpu
();
compare_bn_conv_gpu_and_cpu
();
}
}
}
a
;
}
a
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment