Commit c49631dd authored Oct 20, 2015 by Davis King
Changed tensor layout from NHWC to NCHW since this layout is much faster for
cuDNN.
parent 30c4ee54
Showing 6 changed files with 57 additions and 52 deletions:

    dlib/dnn/core.h             +6  -5
    dlib/dnn/cudnn_dlibapi.cpp  +9  -9
    dlib/dnn/cudnn_dlibapi.h    +6  -6
    dlib/dnn/input.h            +16 -4
    dlib/dnn/layers.h           +1  -1
    dlib/dnn/tensor.h           +19 -27
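For readers skimming the diff, the substance of the change is the mapping from a logical element (sample n, channel k, row r, column c) to a flat index in memory. A minimal sketch of the two mappings (illustrative code, not part of the commit; the function names are hypothetical):

    // A tensor stores num_samples*k*nr*nc floats under either layout; only
    // the flat index of logical element (n, k, r, c) differs.

    // NHWC (old dlib layout): the k channel values of a pixel sit together.
    long index_nhwc(long n, long k, long r, long c, long nk, long nr, long nc)
    { return ((n*nr + r)*nc + c)*nk + k; }

    // NCHW (new dlib layout): each channel is a contiguous nr*nc plane,
    // which is the layout cuDNN's kernels handle fastest.
    long index_nchw(long n, long k, long r, long c, long nk, long nr, long nc)
    { return ((n*nk + k)*nr + r)*nc + c; }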
dlib/dnn/core.h

@@ -1362,7 +1362,7 @@ namespace dlib
     namespace timpl
     {
-        void fill_with_gassuan_random_numbers (
+        inline void fill_with_gassuan_random_numbers (
             tensor& t,
             dlib::rand& rnd,
             double sigma = 1
@@ -1383,12 +1383,12 @@ namespace dlib
             // Output and gradient_input have to have the same dimensions in each
             // layer.
             const long num_samples = rnd.get_random_32bit_number()%4+3;
+            const long k  = rnd.get_random_32bit_number()%4+2;
             const long nr = rnd.get_random_32bit_number()%4+2;
             const long nc = rnd.get_random_32bit_number()%4+2;
-            const long k  = rnd.get_random_32bit_number()%4+2;

-            output.set_size(num_samples, nr, nc, k);
-            gradient_input.set_size(num_samples, nr, nc, k);
+            output.set_size(num_samples, k, nr, nc);
+            gradient_input.set_size(num_samples, k, nr, nc);

             // Use a non-zero initial gradient to make sure the layers add to it
             // rather than assign and blow away the initial value.
@@ -1447,7 +1447,8 @@ namespace dlib
     };

-    void print_tensor (
+    // TODO, remove?
+    inline void print_tensor (
         const tensor& a
     )
     {
dlib/dnn/cudnn_dlibapi.cpp

@@ -85,9 +85,9 @@ namespace dlib
         void tensor_descriptor::
         set_size(
             int n,
+            int k,
             int nr,
-            int nc,
-            int k
+            int nc
         )
         {
             if (n == 0 || nr == 0 || nc == 0 || k == 0)
@@ -105,7 +105,7 @@ namespace dlib
             handle = h;

             check(cudnnSetTensor4dDescriptor((cudnnTensorDescriptor_t)handle,
-                                    CUDNN_TENSOR_NHWC,
+                                    CUDNN_TENSOR_NCHW,
                                     CUDNN_DATA_FLOAT,
                                     n,
                                     k,
@@ -117,9 +117,9 @@ namespace dlib
         void tensor_descriptor::
         get_size (
             int& n,
-            int& nr,
-            int& nc,
-            int& k
+            int& k,
+            int& nr,
+            int& nc
         ) const
         {
             if (handle)
@@ -140,9 +140,9 @@ namespace dlib
             else
             {
                 n = 0;
+                k = 0;
                 nr = 0;
                 nc = 0;
-                k = 0;
             }
         }
@@ -254,7 +254,7 @@ namespace dlib
                     &out_nc));

             tensor_descriptor dest_desc;
-            dest_desc.set_size(out_num_samples, out_nr, out_nc, out_k);
+            dest_desc.set_size(out_num_samples, out_k, out_nr, out_nc);

             cudnnConvolutionFwdAlgo_t forward_best_algo;
             check(cudnnGetConvolutionForwardAlgorithm(
@@ -299,7 +299,7 @@ namespace dlib
             const tensor& filters
         )
         {
-            output.set_size(out_num_samples, out_nr, out_nc, out_k);
+            output.set_size(out_num_samples, out_k, out_nr, out_nc);

             // TODO, remove
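The one-line NHWC to NCHW switch above works because cuDNN's 4d tensor descriptor takes the memory format as an enum alongside the four logical dimensions. A reduced sketch of the call (error handling trimmed; this mirrors, but is not copied from, the code in this file):

    // Describe an n x k x nr x nc float tensor to cuDNN in NCHW order.
    cudnnTensorDescriptor_t desc;
    cudnnCreateTensorDescriptor(&desc);
    cudnnStatus_t s = cudnnSetTensor4dDescriptor(
        desc,
        CUDNN_TENSOR_NCHW,   // memory layout: one contiguous nr*nc plane per channel
        CUDNN_DATA_FLOAT,
        n, k, nr, nc);       // cuDNN names these (n, c, h, w)
    if (s != CUDNN_STATUS_SUCCESS) { /* report the error */ }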
dlib/dnn/cudnn_dlibapi.h

@@ -37,9 +37,9 @@ namespace dlib
             void set_size(
                 int n,
+                int k,
                 int nr,
-                int nc,
-                int k
+                int nc
             );
             /*!
                 ensures
@@ -48,9 +48,9 @@ namespace dlib
             void get_size (
                 int& n,
-                int& nr,
-                int& nc,
-                int& k
+                int& k,
+                int& nr,
+                int& nc
             ) const;

             const void* get_handle (
@@ -209,9 +209,9 @@ namespace dlib
             // dimensions of the output tensor from operator()
             int out_num_samples;
+            int out_k;
             int out_nr;
             int out_nc;
-            int out_k;

             int forward_algo;
             size_t forward_workspace_size_in_bytes;
dlib/dnn/input.h

@@ -56,8 +56,9 @@ namespace dlib
             // initialize data to the right size to contain the stuff in the iterator range.
-            data.set_size(std::distance(ibegin,iend), nr, nc, pixel_traits<T>::num);
+            data.set_size(std::distance(ibegin,iend), pixel_traits<T>::num, nr, nc);

+            const size_t offset = nr*nc;
             auto ptr = data.host();
             for (auto i = ibegin; i != iend; ++i)
             {
@@ -66,10 +67,15 @@ namespace dlib
                     for (long c = 0; c < nc; ++c)
                     {
                         auto temp = pixel_to_vector<float>((*i)(r,c));
+                        auto p = ptr++;
                         for (long j = 0; j < temp.size(); ++j)
-                            *ptr++ = temp(j);
+                        {
+                            *p = temp(j);
+                            p += offset;
+                        }
                     }
                 }
+                ptr += offset*(data.k()-1);
             }
         }
@@ -123,8 +129,9 @@ namespace dlib
             // initialize data to the right size to contain the stuff in the iterator range.
-            data.set_size(std::distance(ibegin,iend), nr, nc, pixel_traits<T>::num);
+            data.set_size(std::distance(ibegin,iend), pixel_traits<T>::num, nr, nc);

+            const size_t offset = nr*nc;
             auto ptr = data.host();
             for (auto i = ibegin; i != iend; ++i)
             {
@@ -133,10 +140,15 @@ namespace dlib
                     for (long c = 0; c < nc; ++c)
                     {
                         auto temp = pixel_to_vector<float>((*i)[r][c]);
+                        auto p = ptr++;
                         for (long j = 0; j < temp.size(); ++j)
-                            *ptr++ = temp(j);
+                        {
+                            *p = temp(j);
+                            p += offset;
+                        }
                     }
                 }
+                ptr += offset*(data.k()-1);
             }
         }
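The pointer arithmetic in those two loops is the heart of the layout change: instead of writing each pixel's channel values consecutively (NHWC), every value now lands in its own nr*nc channel plane, one plane stride apart. A standalone sketch of the same idea for a 3-channel image (hypothetical names, not dlib code):

    // Sketch: scatter interleaved RGB pixels into planar (NCHW) storage.
    // 'dst' must hold 3*nr*nc floats; 'offset' is the plane stride.
    #include <cstddef>

    void rgb_to_planar(const unsigned char* rgb, float* dst, long nr, long nc)
    {
        const std::size_t offset = nr*nc;  // elements per channel plane
        float* ptr = dst;
        for (long r = 0; r < nr; ++r)
        for (long c = 0; c < nc; ++c)
        {
            float* p = ptr++;              // this pixel's slot in plane 0
            for (long j = 0; j < 3; ++j)
            {
                *p = *rgb++;               // channel j of pixel (r,c)
                p += offset;               // hop to the same pixel in plane j+1
            }
        }
        // ptr has now swept exactly one plane (nr*nc elements).  In the dlib
        // code the same pattern repeats per sample, hence the final
        // ptr += offset*(data.k()-1) to skip past the remaining planes.
    }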
dlib/dnn/layers.h

@@ -83,7 +83,7 @@ namespace dlib
         template <typename SUBNET>
         void forward(const SUBNET& sub, resizable_tensor& output)
         {
-            output.set_size(sub.get_output().num_samples(), 1, 1, num_outputs);
+            output.set_size(sub.get_output().num_samples(), num_outputs);
             output = mat(sub.get_output())*mat(params);
         }
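With set_size's parameter order now (n, k = 1, nr = 1, nc = 1) (see the tensor.h hunks below), the two-argument call above relies on the defaults. A brief sketch of the equivalence:

    // Under the new signature these two calls produce the same
    // num_samples x num_outputs x 1 x 1 tensor; under the old
    // (n, nr, nc, k) order the trailing-k form was required instead.
    output.set_size(num_samples, num_outputs);        // k=num_outputs, nr=nc=1
    output.set_size(num_samples, num_outputs, 1, 1);  // spelled out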
dlib/dnn/tensor.h

@@ -22,30 +22,22 @@ namespace dlib
         tensor (
-        ) : m_n(0), m_nr(0), m_nc(0), m_k(0)
+        ) : m_n(0), m_k(0), m_nr(0), m_nc(0)
         {
         }

         inline virtual ~tensor() = 0;

         long num_samples() const { return m_n; }
+        long k() const { return m_k; }
         long nr() const { return m_nr; }
         long nc() const { return m_nc; }
-        long k() const { return m_k; }
         size_t size() const { return data.size(); }

-        void async_copy_to_device() { data.async_copy_to_device(); }
-        /*!
-            ensures
-                - begin asynchronously copying this tensor to the GPU.
-                  NOTE that the "get device pointer" routine in this class
-                  will have to do some kind of synchronization that ensures
-                  the copy is finished.
-        !*/

         const float* host() const { return data.host(); }
         float*       host()       { return data.host(); }
@@ -135,13 +127,13 @@ namespace dlib
         tensor& operator= (const tensor& item)
         {
             m_n  = item.m_n;
+            m_k  = item.m_k;
             m_nr = item.m_nr;
             m_nc = item.m_nc;
-            m_k  = item.m_k;
             data.set_size(item.data.size());
             std::memcpy(data.host(), item.data.host(), data.size()*sizeof(float));
 #ifdef DLIB_USE_CUDA
-            cudnn_descriptor.set_size(m_n,m_nr,m_nc,m_k);
+            cudnn_descriptor.set_size(m_n,m_k,m_nr,m_nc);
 #endif
             return *this;
         }
@@ -159,9 +151,9 @@ namespace dlib
         void swap(tensor& item)
         {
             std::swap(m_n,  item.m_n);
+            std::swap(m_k,  item.m_k);
             std::swap(m_nr, item.m_nr);
             std::swap(m_nc, item.m_nc);
-            std::swap(m_k,  item.m_k);
             std::swap(data, item.data);
 #ifdef DLIB_USE_CUDA
             std::swap(cudnn_descriptor, item.cudnn_descriptor);
@@ -170,9 +162,9 @@ namespace dlib
         long m_n;
+        long m_k;
         long m_nr;
         long m_nc;
-        long m_k;
         gpu_data data;
 #ifdef DLIB_USE_CUDA
         cuda::tensor_descriptor cudnn_descriptor;
@@ -227,9 +219,9 @@ namespace dlib
     )
     {
         return a.num_samples() == b.num_samples() &&
+               a.k()  == b.k()  &&
                a.nr() == b.nr() &&
-               a.nc() == b.nc() &&
-               a.k()  == b.k();
+               a.nc() == b.nc();
     }

     // ----------------------------------------------------------------------------------------
@@ -242,10 +234,10 @@ namespace dlib
         {}

         explicit resizable_tensor(
-            long n_, long nr_ = 1, long nc_ = 1, long k_ = 1
+            long n_, long k_ = 1, long nr_ = 1, long nc_ = 1
         )
         {
-            set_size(n_,nr_,nc_,k_);
+            set_size(n_,k_,nr_,nc_);
         }

         resizable_tensor(const resizable_tensor&) = default;
@@ -265,7 +257,7 @@ namespace dlib
             - resizes *this so that: have_same_dimensions(#*this, item)==true
         !*/
         {
-            set_size(item.num_samples(), item.nr(), item.nc(), item.k());
+            set_size(item.num_samples(), item.k(), item.nr(), item.nc());
         }

         resizable_tensor& operator= (float val)
@@ -323,16 +315,16 @@ namespace dlib
         }

         void set_size(
-            long n_, long nr_ = 1, long nc_ = 1, long k_ = 1
+            long n_, long k_ = 1, long nr_ = 1, long nc_ = 1
         )
         {
             m_n  = n_;
+            m_k  = k_;
             m_nr = nr_;
             m_nc = nc_;
-            m_k  = k_;
-            data.set_size(m_n*m_nr*m_nc*m_k);
+            data.set_size(m_n*m_k*m_nr*m_nc);
 #ifdef DLIB_USE_CUDA
-            cudnn_descriptor.set_size(m_n,m_nr,m_nc,m_k);
+            cudnn_descriptor.set_size(m_n,m_k,m_nr,m_nc);
 #endif
         }
     };
@@ -342,9 +334,9 @@ namespace dlib
         int version = 1;
         serialize(version, out);
         serialize(item.num_samples(), out);
+        serialize(item.k(), out);
         serialize(item.nr(), out);
         serialize(item.nc(), out);
-        serialize(item.k(), out);
         auto data = item.host();
         for (size_t i = 0; i < item.size(); ++i)
             serialize(data[i], out);
@@ -357,12 +349,12 @@ namespace dlib
         if (version != 1)
             throw serialization_error("Unexpected version found while deserializing dlib::resizable_tensor.");

-        long num_samples=0, nr=0, nc=0, k=0;
+        long num_samples=0, k=0, nr=0, nc=0;
         deserialize(num_samples, in);
+        deserialize(k, in);
         deserialize(nr, in);
         deserialize(nc, in);
-        deserialize(k, in);
-        item.set_size(num_samples, nr, nc, k);
+        item.set_size(num_samples, k, nr, nc);
         auto data = item.host();
         for (size_t i = 0; i < item.size(); ++i)
             deserialize(data[i], in);
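One practical consequence of the serialize/deserialize hunks above: the on-disk dimension order now follows the new layout, i.e. version, num_samples, k, nr, nc, then the raw floats (previously num_samples, nr, nc, k). A hypothetical round-trip sketch using the overloads in this file (stream setup is illustrative, not from the commit):

    #include <sstream>

    dlib::resizable_tensor t(2, 3, 4, 5);   // n=2, k=3, nr=4, nc=5 (new order)
    std::ostringstream out;
    serialize(t, out);                       // writes dims in the new order

    dlib::resizable_tensor t2;
    std::istringstream in(out.str());
    deserialize(t2, in);
    // have_same_dimensions(t, t2) == true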