钟尚武 / dlib · Commits

Commit 1f0318e2, authored May 26, 2016 by Fm
Parent: 93e786db

depth_group replaced with concat layer

Showing 11 changed files with 546 additions and 182 deletions (+546 −182)
Changed files:

    dlib/dnn/core.h                  +0   −0
    dlib/dnn/cpu_dlib.cpp            +23  −44
    dlib/dnn/cpu_dlib.h              +8   −13
    dlib/dnn/cuda_dlib.cu            +16  −35
    dlib/dnn/cuda_dlib.h             +5   −9
    dlib/dnn/layers.h                +157 −0
    dlib/dnn/layers_abstract.h       +82  −0
    dlib/dnn/tensor_tools.cpp        +13  −16
    dlib/dnn/tensor_tools.h          +19  −35
    dlib/test/dnn.cpp                +189 −1
    examples/dnn_inception_ex.cpp    +34  −29
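At the network-definition level, this change swaps the tuple-based grp/group_input grouping for tagged branches that the new concat_ layer joins along the channel (k) dimension. A minimal sketch of the new style (the branch definitions here are hypothetical, not taken from the commit):

    #include <dlib/dnn.h>
    using namespace dlib;

    // two hypothetical branches; both read the shared input marked by itag0
    template <typename SUBNET> using branch1 = con<8,1,1,1,1, SUBNET>;
    template <typename SUBNET> using branch2 = con<8,3,3,1,1, SUBNET>;

    // inception2 (added in this commit) tags each branch's output and concatenates
    // them along the channel (k) dimension, giving 8 + 8 = 16 output channels
    template <typename SUBNET> using my_inception = inception2<branch1, branch2, SUBNET>;

    using net_type = loss_multiclass_log<fc<10, my_inception<input<matrix<unsigned char>>>>>;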
dlib/dnn/core.h

(diff collapsed; +0 −0)
dlib/dnn/cpu_dlib.cpp

@@ -1783,58 +1783,37 @@ namespace dlib
                 filters_gradient += gi*temp;
             }
         }

     // ------------------------------------------------------------------------------------

-        void concat_depth(
+        void copy_tensor(
             tensor& dest,
-            size_t sample_offset,
-            const tensor& src
+            size_t dest_k_offset,
+            const tensor& src,
+            size_t src_k_offset,
+            size_t count_k
         )
         {
             const size_t dest_sample_size = static_cast<size_t>(dest.nc() * dest.nr() * dest.k());
             const size_t src_sample_size = static_cast<size_t>(src.nc() * src.nr() * src.k());
+
+            const size_t block_size = count_k * dest.nc() * dest.nr();

             DLIB_CASSERT(dest.num_samples() == src.num_samples() &&
                 dest.nc() == src.nc() && dest.nr() == src.nr(), "All sources should fit into dest tensor size");
-            DLIB_CASSERT(dest_sample_size >= src_sample_size + sample_offset, "Not enough space in dest tensor");
+            DLIB_CASSERT(dest.k() - dest_k_offset >= count_k, "Not enough space in dest tensor");
+            DLIB_CASSERT(src.k() - src_k_offset >= count_k, "Not enough space in src tensor");

-            float* dest_p = dest.host_write_only() + sample_offset;
-            const float* src_p = src.host();
+            float* dest_p = dest.host() + dest_k_offset * dest.nc() * dest.nr();
+            const float* src_p = src.host() + src_k_offset * src.nc() * src.nr();

             for (unsigned long i = 0; i < src.num_samples(); ++i)
             {
-                ::memcpy(dest_p, src_p, src_sample_size * sizeof(float));
+                ::memcpy(dest_p, src_p, block_size * sizeof(float));

                 dest_p += dest_sample_size;
                 src_p += src_sample_size;
             }
         }

-        void split_depth(
-            tensor& dest,
-            size_t sample_offset,
-            const tensor& src
-        )
-        {
-            const size_t dest_sample_size = static_cast<size_t>(dest.nc() * dest.nr() * dest.k());
-            const size_t src_sample_size = static_cast<size_t>(src.nc() * src.nr() * src.k());
-
-            DLIB_CASSERT(dest.num_samples() == src.num_samples() &&
-                dest.nc() == src.nc() && dest.nr() == src.nr(),
-                "All sources should fit into dest tensor size");
-            DLIB_CASSERT(dest_sample_size <= src_sample_size - sample_offset, "Not enough space in dest tensor");
-
-            float* dest_p = dest.host_write_only();
-            const float* src_p = src.host() + sample_offset;
-
-            for (unsigned long i = 0; i < src.num_samples(); ++i)
-            {
-                ::memcpy(dest_p, src_p, dest_sample_size * sizeof(float));
-                dest_p += dest_sample_size;
-                src_p += src_sample_size;
-            }
-        }

     // ------------------------------------------------------------------------------------
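Why a single memcpy per sample is enough: dlib tensors store each sample's channels contiguously in row-major [num_samples, k, nr, nc] order (the same layout the tensor_read_cpu helper in the tests below assumes), so channels [k_offset, k_offset+count_k) of one sample form one contiguous block of count_k*nr*nc floats. A small self-contained sketch of that index arithmetic (the flat_index helper is hypothetical, not part of dlib):

    #include <cassert>
    #include <cstddef>

    // flat index of element (i, k, r, c) in a row-major n x K x NR x NC tensor,
    // the layout copy_tensor assumes
    inline std::size_t flat_index(std::size_t i, std::size_t k, std::size_t r,
                                  std::size_t c, std::size_t K, std::size_t NR, std::size_t NC)
    {
        return ((i*K + k)*NR + r)*NC + c;
    }

    int main()
    {
        const std::size_t K = 9, NR = 7, NC = 15;   // dims matching the tests later in this commit
        const std::size_t i = 2, k0 = 3, count_k = 4;
        // channels [k0, k0+count_k) of sample i span exactly count_k*NR*NC elements,
        // which is the block_size that copy_tensor memcpy's once per sample
        assert(flat_index(i, k0+count_k, 0, 0, K, NR, NC)
             - flat_index(i, k0, 0, 0, K, NR, NC) == count_k*NR*NC);
    }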
dlib/dnn/cpu_dlib.h

@@ -384,19 +384,14 @@ namespace dlib
             long last_padding_x;
         };

     // -----------------------------------------------------------------------------------

-        void concat_depth(
-            tensor& dest,
-            size_t sample_offset,
-            const tensor& src
-        );
-
-        void split_depth(
-            tensor& dest,
-            size_t sample_offset,
-            const tensor& src
-        );
+        void copy_tensor(
+            tensor& dest,
+            size_t dest_k_offset,
+            const tensor& src,
+            size_t src_k_offset,
+            size_t count_k
+        );

     // -----------------------------------------------------------------------------------
 }
dlib/dnn/cuda_dlib.cu

@@ -796,57 +796,38 @@ namespace dlib
             grad.device(), src.device(), gradient_input.device(), grad.size(),
             param.device(), params_grad.device());
         }

     // ----------------------------------------------------------------------------------------

-        void concat_depth(
+        void copy_tensor(
             tensor& dest,
-            size_t sample_offset,
-            const tensor& src
+            size_t dest_k_offset,
+            const tensor& src,
+            size_t src_k_offset,
+            size_t count_k
         )
         {
             const size_t dest_sample_size = static_cast<size_t>(dest.nc() * dest.nr() * dest.k());
             const size_t src_sample_size = static_cast<size_t>(src.nc() * src.nr() * src.k());
+            const size_t block_size = count_k * dest.nc() * dest.nr();

             DLIB_CASSERT(dest.num_samples() == src.num_samples() &&
                 dest.nc() == src.nc() && dest.nr() == src.nr(), "All sources should fit into dest tensor size");
-            DLIB_CASSERT(dest_sample_size >= src_sample_size + sample_offset, "Not enough space in dest tensor");
+            DLIB_CASSERT(dest.k() - dest_k_offset >= count_k, "Not enough space in dest tensor");
+            DLIB_CASSERT(src.k() - src_k_offset >= count_k, "Not enough space in src tensor");

-            float* dest_p = dest.device_write_only() + sample_offset;
-            const float* src_p = src.device();
+            float* dest_p = dest.device() + dest_k_offset * dest.nc() * dest.nr();
+            const float* src_p = src.device() + src_k_offset * src.nc() * src.nr();

             for (unsigned long i = 0; i < src.num_samples(); ++i)
             {
-                CHECK_CUDA(cudaMemcpy(dest_p, src_p, src_sample_size * sizeof(float), cudaMemcpyDeviceToDevice));
+                CHECK_CUDA(cudaMemcpy(dest_p, src_p, block_size * sizeof(float), cudaMemcpyDeviceToDevice));

                 dest_p += dest_sample_size;
                 src_p += src_sample_size;
             }
         }

-        void split_depth(
-            tensor& dest,
-            size_t sample_offset,
-            const tensor& src
-        )
-        {
-            const size_t dest_sample_size = static_cast<size_t>(dest.nc() * dest.nr() * dest.k());
-            const size_t src_sample_size = static_cast<size_t>(src.nc() * src.nr() * src.k());
-
-            DLIB_CASSERT(dest.num_samples() == src.num_samples() &&
-                dest.nc() == src.nc() && dest.nr() == src.nr(),
-                "All sources should fit into dest tensor size");
-            DLIB_CASSERT(dest_sample_size <= src_sample_size - sample_offset, "Not enough space in dest tensor");
-
-            float* dest_p = dest.device_write_only();
-            const float* src_p = src.device() + sample_offset;
-
-            for (unsigned long i = 0; i < src.num_samples(); ++i)
-            {
-                CHECK_CUDA(cudaMemcpy(dest_p, src_p, dest_sample_size * sizeof(float), cudaMemcpyDeviceToDevice));
-                dest_p += dest_sample_size;
-                src_p += src_sample_size;
-            }
-        }

     // ----------------------------------------------------------------------------------------
 }
dlib/dnn/cuda_dlib.h

@@ -258,16 +258,12 @@ namespace dlib
             tensor& params_grad
         );

-        void concat_depth(
-            tensor& dest,
-            size_t sample_offset,
-            const tensor& src
-        );
-
-        void split_depth(
-            tensor& dest,
-            size_t sample_offset,
-            const tensor& src
-        );
+        void copy_tensor(
+            tensor& dest,
+            size_t dest_k_offset,
+            const tensor& src,
+            size_t src_k_offset,
+            size_t count_k
+        );

     // ------------------------------------------------------------------------------------
dlib/dnn/layers.h

@@ -1836,6 +1836,163 @@ namespace dlib
     template <typename SUBNET>
     using softmax = add_layer<softmax_, SUBNET>;

 // ----------------------------------------------------------------------------------------

+    namespace impl
+    {
+        // helper classes for layer concat processing
+        template <template<typename> class... TAG_TYPES>
+        struct concat_helper_impl
+        {
+        };
+
+        template <template<typename> class TAG_TYPE>
+        struct concat_helper_impl<TAG_TYPE>
+        {
+            template<typename SUBNET>
+            static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k)
+            {
+                auto& t = layer<TAG_TYPE>(sub).get_output();
+                out.set_size(t.num_samples(), t.k() + sum_k, t.nr(), t.nc());
+            }
+            template<typename SUBNET>
+            static void concat(tensor& out, const SUBNET& sub, size_t k_offset)
+            {
+                auto& t = layer<TAG_TYPE>(sub).get_output();
+                tt::copy_tensor(out, k_offset, t, 0, t.k());
+            }
+            template<typename SUBNET>
+            static void split(const tensor& input, SUBNET& sub, size_t k_offset)
+            {
+                auto& t = layer<TAG_TYPE>(sub).get_gradient_input();
+                tt::copy_tensor(t, 0, input, k_offset, t.k());
+            }
+        };
+
+        template <template<typename> class TAG_TYPE, template<typename> class... TAG_TYPES>
+        struct concat_helper_impl<TAG_TYPE, TAG_TYPES...>
+        {
+            template<typename SUBNET>
+            static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k)
+            {
+                auto& t = layer<TAG_TYPE>(sub).get_output();
+                concat_helper_impl<TAG_TYPES...>::resize_out(out, sub, sum_k + t.k());
+            }
+            template<typename SUBNET>
+            static void concat(tensor& out, const SUBNET& sub, size_t k_offset)
+            {
+                auto& t = layer<TAG_TYPE>(sub).get_output();
+                tt::copy_tensor(out, k_offset, t, 0, t.k());
+                k_offset += t.k();
+                concat_helper_impl<TAG_TYPES...>::concat(out, sub, k_offset);
+            }
+            template<typename SUBNET>
+            static void split(const tensor& input, SUBNET& sub, size_t k_offset)
+            {
+                auto& t = layer<TAG_TYPE>(sub).get_gradient_input();
+                tt::copy_tensor(t, 0, input, k_offset, t.k());
+                k_offset += t.k();
+                concat_helper_impl<TAG_TYPES...>::split(input, sub, k_offset);
+            }
+        };
+    }
+
+    // concat layer
+    template<template<typename> class... TAG_TYPES>
+    class concat_
+    {
+    public:
+        template <typename SUBNET>
+        void setup (const SUBNET&)
+        {
+            // do nothing
+        }
+
+        template <typename SUBNET>
+        void forward(const SUBNET& sub, resizable_tensor& output)
+        {
+            // the total depth of the result is the sum of the depths from all tags
+            impl::concat_helper_impl<TAG_TYPES...>::resize_out(output, sub, 0);
+            // copy the output of each tag into its own part of the result
+            impl::concat_helper_impl<TAG_TYPES...>::concat(output, sub, 0);
+        }
+
+        template <typename SUBNET>
+        void backward(const tensor& gradient_input, SUBNET& sub, tensor&)
+        {
+            // the gradient is split into parts, one for each tag layer
+            impl::concat_helper_impl<TAG_TYPES...>::split(gradient_input, sub, 0);
+        }
+
+        const tensor& get_layer_params() const { return params; }
+        tensor& get_layer_params() { return params; }
+
+        friend void serialize(const concat_& item, std::ostream& out)
+        {
+            serialize("concat_", out);
+            serialize(sizeof...(TAG_TYPES), out);
+        }
+
+        friend void deserialize(concat_& item, std::istream& in)
+        {
+            std::string version;
+            deserialize(version, in);
+            if (version != "concat_")
+                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::concat_.");
+            size_t count_tags;
+            deserialize(count_tags, in);
+            if (count_tags != sizeof...(TAG_TYPES))
+                throw serialization_error("Invalid count of tags "+std::to_string(count_tags)+
+                                          ", expecting "+std::to_string(sizeof...(TAG_TYPES))+
+                                          " found while deserializing dlib::concat_.");
+        }
+
+        friend std::ostream& operator<<(std::ostream& out, const concat_& item)
+        {
+            out << "concat\t(" << sizeof...(TAG_TYPES) << ")";
+            return out;
+        }
+
+    private:
+        resizable_tensor params; // unused
+    };
+
+    template <typename SUBNET, template<typename> class... TAG_TYPES>
+    using concat = add_layer<concat_<TAG_TYPES...>, SUBNET>;
+
+    // The inception layer uses tags internally. If the user also uses tags, conflicts are
+    // possible, so to avoid them here are new tags used only by the inception layers.
+    template <typename SUBNET> using itag0 = add_tag_layer<1000 + 0, SUBNET>;
+    template <typename SUBNET> using itag1 = add_tag_layer<1000 + 1, SUBNET>;
+    template <typename SUBNET> using itag2 = add_tag_layer<1000 + 2, SUBNET>;
+    template <typename SUBNET> using itag3 = add_tag_layer<1000 + 3, SUBNET>;
+    template <typename SUBNET> using itag4 = add_tag_layer<1000 + 4, SUBNET>;
+    template <typename SUBNET> using itag5 = add_tag_layer<1000 + 5, SUBNET>;
+    // skip to the inception input
+    template <typename SUBNET> using iskip = add_skip_layer<itag0, SUBNET>;
+
+    // here are some templates to be used for creating inception layer groups
+    template <template<typename>class B1, template<typename>class B2, typename SUBNET>
+    using inception2 = concat<itag1<B1<iskip<itag2<B2<itag0<SUBNET>>>>>>, itag1, itag2>;
+
+    template <template<typename>class B1, template<typename>class B2, template<typename>class B3, typename SUBNET>
+    using inception3 = concat<itag1<B1<iskip<itag2<B2<iskip<itag3<B3<itag0<SUBNET>>>>>>>>>,
+                              itag1, itag2, itag3>;
+
+    template <template<typename>class B1, template<typename>class B2, template<typename>class B3,
+              template<typename>class B4, typename SUBNET>
+    using inception4 = concat<itag1<B1<iskip<itag2<B2<iskip<itag3<B3<iskip<itag4<B4<itag0<SUBNET>>>>>>>>>>>>,
+                              itag1, itag2, itag3, itag4>;
+
+    template <template<typename>class B1, template<typename>class B2, template<typename>class B3,
+              template<typename>class B4, template<typename>class B5, typename SUBNET>
+    using inception5 = concat<itag1<B1<iskip<itag2<B2<iskip<itag3<B3<iskip<itag4<B4<iskip<itag5<B5<itag0<SUBNET>>>>>>>>>>>>>>>,
+                              itag1, itag2, itag3, itag4, itag5>;

 // ----------------------------------------------------------------------------------------
 }
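To see how the tags wire the branches together, here is what the inception2 alias above expands to, written out by hand (a sketch for illustration; the branch definitions are hypothetical):

    #include <dlib/dnn.h>
    using namespace dlib;

    template <typename SUBNET> using b1 = con<4,1,1,1,1, SUBNET>;   // hypothetical branch
    template <typename SUBNET> using b2 = con<6,3,3,1,1, SUBNET>;   // hypothetical branch

    // itag0 marks the shared input, iskip rewinds the second branch back to it, and
    // concat gathers the outputs tagged itag1 and itag2 depth-wise: k() == 4 + 6 == 10
    using two_branch = concat<itag1<b1<iskip<itag2<b2<itag0<input<matrix<float>>>>>>>>,
                              itag1, itag2>;
    // equivalently: using two_branch = inception2<b1, b2, input<matrix<float>>>;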
dlib/dnn/layers_abstract.h

@@ -1652,6 +1652,88 @@ namespace dlib
     using add_prev9_  = add_prev_<tag9>;
     using add_prev10_ = add_prev_<tag10>;

 // ----------------------------------------------------------------------------------------

+    template <template<typename> class... TAG_TYPES>
+    class concat_
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+                This is an implementation of the EXAMPLE_COMPUTATIONAL_LAYER_ interface
+                defined above.  This layer simply concatenates the outputs of the required
+                layers.  In particular, it copies each layer's output from TAG_TYPES into
+                the corresponding place in the result tensor, thus producing the combined
+                output.  The output of each tag layer is stored in a separate part of the
+                final output.
+
+                FORWARD:
+                    for each tag in TAG_TYPES:
+                        output[i, k + tag.k(), r, c] = layer<tag>(subnet).get_output[i, k, r, c]
+
+                BACKWARD:
+                    for each tag in TAG_TYPES:
+                        layer<tag>(subnet).get_gradient_input[i, k, r, c] = input[i, k + tag.k(), r, c]
+
+                This layer can only be used with tags inside.  Each tagged layer must have
+                identical num_samples, R, and C sizes.  The output will have K equal to the
+                sum of the tagged layers' K values, and the output's num_samples, R, and C
+                will be the same as the tagged layers'.
+        !*/
+
+    public:
+        template <typename SUBNET> void setup (const SUBNET& sub);
+        template <typename SUBNET> void forward(const SUBNET& sub, resizable_tensor& output);
+        template <typename SUBNET> void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
+        const tensor& get_layer_params() const;
+        tensor& get_layer_params();
+        /*!
+            These functions are implemented as described in the EXAMPLE_COMPUTATIONAL_LAYER_ interface.
+        !*/
+    };
+
+    template <typename SUBNET, template<typename> class... TAG_TYPES>
+    using concat = add_layer<concat_<TAG_TYPES...>, SUBNET>;
+
+    // The inception layer uses tags internally. If the user also uses tags, conflicts are
+    // possible, so to avoid them here are new tags used only by the inception layers.
+    template <typename SUBNET> using itag0 = add_tag_layer<1000 + 0, SUBNET>;
+    template <typename SUBNET> using itag1 = add_tag_layer<1000 + 1, SUBNET>;
+    template <typename SUBNET> using itag2 = add_tag_layer<1000 + 2, SUBNET>;
+    template <typename SUBNET> using itag3 = add_tag_layer<1000 + 3, SUBNET>;
+    template <typename SUBNET> using itag4 = add_tag_layer<1000 + 4, SUBNET>;
+    template <typename SUBNET> using itag5 = add_tag_layer<1000 + 5, SUBNET>;
+    // skip to the inception input
+    template <typename SUBNET> using iskip = add_skip_layer<itag0, SUBNET>;
+
+    // here are some templates to be used for creating inception layer groups
+    template <template<typename>class B1, template<typename>class B2, typename SUBNET>
+    using inception2 = concat<itag1<B1<iskip<itag2<B2<itag0<SUBNET>>>>>>, itag1, itag2>;
+
+    template <template<typename>class B1, template<typename>class B2, template<typename>class B3, typename SUBNET>
+    using inception3 = concat<itag1<B1<iskip<itag2<B2<iskip<itag3<B3<itag0<SUBNET>>>>>>>>>,
+                              itag1, itag2, itag3>;
+
+    template <template<typename>class B1, template<typename>class B2, template<typename>class B3,
+              template<typename>class B4, typename SUBNET>
+    using inception4 = concat<itag1<B1<iskip<itag2<B2<iskip<itag3<B3<iskip<itag4<B4<itag0<SUBNET>>>>>>>>>>>>,
+                              itag1, itag2, itag3, itag4>;
+
+    template <template<typename>class B1, template<typename>class B2, template<typename>class B3,
+              template<typename>class B4, template<typename>class B5, typename SUBNET>
+    using inception5 = concat<itag1<B1<iskip<itag2<B2<iskip<itag3<B3<iskip<itag4<B4<iskip<itag5<B5<itag0<SUBNET>>>>>>>>>>>>>>>,
+                              itag1, itag2, itag3, itag4, itag5>;

 // ----------------------------------------------------------------------------------------
 }
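To make the FORWARD/BACKWARD contract above concrete, a small runnable sketch with hypothetical branch widths (5 and 8 channels, not from the commit):

    #include <dlib/dnn.h>
    #include <cassert>
    using namespace dlib;

    template <typename SUBNET> using five  = con<5,1,1,1,1, SUBNET>;  // hypothetical 5-channel branch
    template <typename SUBNET> using eight = con<8,1,1,1,1, SUBNET>;  // hypothetical 8-channel branch
    using demo = concat<itag1<five<iskip<itag2<eight<itag0<input<matrix<float>>>>>>>>,
                        itag1, itag2>;

    int main()
    {
        demo net;
        resizable_tensor x(2, 3, 16, 16);
        x = 0;
        // FORWARD: output channels [0,5) come from itag1 and [5,13) from itag2;
        // BACKWARD routes the same channel ranges of the incoming gradient back
        // to each tagged layer's get_gradient_input()
        auto& out = net.forward(x);
        assert(out.k() == 5 + 8);
    }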
dlib/dnn/tensor_tools.cpp

@@ -678,26 +678,23 @@ namespace dlib { namespace tt
 #endif
     }

     // ------------------------------------------------------------------------------------

-    void concat_depth(
-        tensor& dest,
-        size_t sample_offset,
-        const tensor& src
-    )
-    {
-#ifdef DLIB_USE_CUDA
-        cuda::concat_depth(dest, sample_offset, src);
-#else
-        cpu::concat_depth(dest, sample_offset, src);
-#endif
-    }
-
-    void split_depth(
-        tensor& dest,
-        size_t sample_offset,
-        const tensor& src
-    )
-    {
-#ifdef DLIB_USE_CUDA
-        cuda::split_depth(dest, sample_offset, src);
-#else
-        cpu::split_depth(dest, sample_offset, src);
-#endif
-    }
+    void copy_tensor(
+        tensor& dest,
+        size_t dest_k_offset,
+        const tensor& src,
+        size_t src_k_offset,
+        size_t count_k
+    )
+    {
+#ifdef DLIB_USE_CUDA
+        cuda::copy_tensor(dest, dest_k_offset, src, src_k_offset, count_k);
+#else
+        cpu::copy_tensor(dest, dest_k_offset, src, src_k_offset, count_k);
+#endif
+    }

 // ----------------------------------------------------------------------------------------
 }}
dlib/dnn/tensor_tools.h

@@ -1234,41 +1234,25 @@ namespace dlib { namespace tt
     };

 // ----------------------------------------------------------------------------------------

-    void concat_depth(
-        tensor& dest,
-        size_t sample_offset,
-        const tensor& src
-    );
-    /*!
-        requires
-            - dest.nc() == src.nc()
-            - dest.nr() == src.nr()
-            - dest.num_samples() == src.num_samples()
-            - dest.k() >= src.k() + sample_offset
-            - is_same_object(dest,src) == false
-            - sample_offset is a count of elements, not bytes
-        ensures
-            - performs: dest[i, k + sample_offset, r, c] = src[i, k, r, c], where k is in
-              [0..src.k()], i.e. copies the content of each sample from src into the
-              corresponding place in dest
-    !*/
-
-    void split_depth(
-        tensor& dest,
-        size_t sample_offset,
-        const tensor& src
-    );
-    /*!
-        requires
-            - dest.nc() == src.nc()
-            - dest.nr() == src.nr()
-            - dest.num_samples() == src.num_samples()
-            - dest.k() <= src.k() - sample_offset
-            - is_same_object(dest,src) == false
-            - sample_offset is a count of elements, not bytes
-        ensures
-            - performs: dest[i, k, r, c] = src[i, k + sample_offset, r, c], where k is in
-              [0..dest.k()], i.e. fills each sample of dest from the corresponding part of
-              each sample in src
-    !*/
+    void copy_tensor(
+        tensor& dest,
+        size_t dest_k_offset,
+        const tensor& src,
+        size_t src_k_offset,
+        size_t count_k
+    );
+    /*!
+        requires
+            - dest.nc() == src.nc()
+            - dest.nr() == src.nr()
+            - dest.num_samples() == src.num_samples()
+            - dest.k() - dest_k_offset >= count_k
+            - src.k() - src_k_offset >= count_k
+            - is_same_object(dest,src) == false
+        ensures
+            - performs: dest[i, k + dest_k_offset, r, c] = src[i, k + src_k_offset, r, c],
+              where k is in [0..count_k], i.e. copies the content of each sample from src
+              into the corresponding place in dest
+    !*/

 // ----------------------------------------------------------------------------------------
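For orientation, a minimal usage sketch of the new routine matching the requires/ensures clauses above (hypothetical sizes, not from the commit):

    #include <dlib/dnn.h>
    using namespace dlib;

    int main()
    {
        // concatenate a and b along the channel (k) dimension into out
        resizable_tensor a(2, 3, 4, 4), b(2, 5, 4, 4), out(2, 3+5, 4, 4);
        a = 1;
        b = 2;
        tt::copy_tensor(out, 0,     a, 0, a.k());   // out channels [0,3) <- a
        tt::copy_tensor(out, a.k(), b, 0, b.k());   // out channels [3,8) <- b
    }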
dlib/test/dnn.cpp

@@ -12,7 +12,77 @@
 #include "tester.h"

+namespace dlib
+{
+    template <typename SUBNET> using concat_block1 = con<5,1,1,1,1,SUBNET>;
+    template <typename SUBNET> using concat_block2 = con<8,3,3,1,1,SUBNET>;
+    template <typename SUBNET> using concat_block3 = max_pool<3,3,1,1,SUBNET>;
+    template <typename SUBNET> using concat_incept = inception3<concat_block1, concat_block2, concat_block3, SUBNET>;
+
+    // this class is a friend of add_layer and can access its private members
+    class dnn_tester
+    {
+    public:
+        // the tester function is a member so that it has access to the private x_grad member of add_layer
+        static void test_concat()
+        {
+            using namespace test;
+            using namespace std;
+            using namespace dlib::tt;
+            print_spinner();
+
+            using net_type = concat_incept<input<matrix<float>>>;
+
+            resizable_tensor data(10, 1, 111, 222);
+            data = matrix_cast<float>(gaussian_randm(data.num_samples(), data.k()*data.nr()*data.nc(), 1));
+
+            net_type net;
+
+            auto& out = net.forward(data);
+
+            auto& b1o = layer<itag1>(net).get_output();
+            auto& b2o = layer<itag2>(net).get_output();
+            auto& b3o = layer<itag3>(net).get_output();
+
+            resizable_tensor dest(10, 14, 111, 222);
+            copy_tensor(dest, 0, b1o, 0, b1o.k());
+            copy_tensor(dest, b1o.k(), b2o, 0, b2o.k());
+            copy_tensor(dest, b1o.k() + b2o.k(), b3o, 0, b3o.k());
+
+            DLIB_TEST(dest.size() == out.size());
+            int error = memcmp(dest.host(), out.host(), dest.size());
+            DLIB_TEST(error == 0);
+
+            resizable_tensor gr(10, 14, 111, 222);
+            gr = matrix_cast<float>(gaussian_randm(gr.num_samples(), gr.k()*gr.nr()*gr.nc(), 1));
+
+            memcpy(net.get_gradient_input(), gr);
+            net.back_propagate_error(data);
+
+            auto& b1g = layer<itag1>(net).subnet().x_grad;
+            auto& b2g = layer<itag2>(net).subnet().x_grad;
+            auto& b3g = layer<itag3>(net).subnet().x_grad;
+
+            resizable_tensor g1(10, 5, 111, 222);
+            resizable_tensor g2(10, 8, 111, 222);
+            resizable_tensor g3(10, 1, 111, 222);
+            copy_tensor(g1, 0, gr, 0, g1.k());
+            copy_tensor(g2, 0, gr, g1.k(), g2.k());
+            copy_tensor(g3, 0, gr, g1.k() + g2.k(), g3.k());
+
+            DLIB_TEST(g1.size() == b1g.size());
+            error = memcmp(g1.host(), b1g.host(), b1g.size());
+            DLIB_TEST(error == 0);
+            DLIB_TEST(g2.size() == b2g.size());
+            error = memcmp(g2.host(), b2g.host(), b2g.size());
+            DLIB_TEST(error == 0);
+            DLIB_TEST(g3.size() == b3g.size());
+            error = memcmp(g3.host(), b3g.host(), b3g.size());
+            DLIB_TEST(error == 0);
+        }
+    };
+}
+
 namespace
 {
     using namespace test;

@@ -1405,6 +1475,121 @@ namespace
         DLIB_TEST(count == pnet.num_computational_layers);
     }

+    float tensor_read_cpu(const tensor& t, long i, long k, long r, long c)
+    {
+        const float* p = t.host() + t.k()*t.nr()*t.nc()*i + t.nr()*t.nc()*k + t.nc()*r + c;
+        return *p;
+    }
+    void test_copy_tensor_cpu()
+    {
+        using namespace dlib::tt;
+        print_spinner();
+        resizable_tensor dest(10, 9, 7, 15);
+        resizable_tensor src1(10, 3, 7, 15);
+        resizable_tensor src2(10, 3, 7, 15);
+        resizable_tensor src3(10, 9, 7, 15);
+        dest = matrix_cast<float>(gaussian_randm(dest.num_samples(), dest.k()*dest.nr()*dest.nc(), 1));
+        src1 = matrix_cast<float>(gaussian_randm(src1.num_samples(), src1.k()*src1.nr()*src1.nc(), 0));
+        src2 = matrix_cast<float>(gaussian_randm(src1.num_samples(), src2.k()*src2.nr()*src2.nc(), 0));
+        src3 = matrix_cast<float>(gaussian_randm(src1.num_samples(), src3.k()*src3.nr()*src3.nc(), 0));
+
+        cpu::copy_tensor(dest, 0, src1, 0, src1.k());           // full copy src1 -> dest
+        cpu::copy_tensor(dest, src1.k(), src2, 0, src2.k());    // full copy src2 -> dest, offset by src1
+        cpu::copy_tensor(dest, src1.k()+src2.k(), src3, 3, 3);  // partial copy of src3 into the rest of dest
+
+        for (long i = 0; i < dest.num_samples(); ++i)
+        {
+            for (long k = 0; k < dest.k(); ++k)
+            {
+                for (long r = 0; r < dest.nr(); ++r)
+                {
+                    for (long c = 0; c < dest.nc(); ++c)
+                    {
+                        float dest_value = tensor_read_cpu(dest, i, k, r, c);
+                        // the first part is from src1
+                        if (k < src1.k())
+                        {
+                            float src_value = tensor_read_cpu(src1, i, k, r, c);
+                            DLIB_TEST(src_value == dest_value);
+                        }
+                        // the second part is from src2
+                        else if (k < src1.k()+src2.k())
+                        {
+                            float src_value = tensor_read_cpu(src2, i, k-src1.k(), r, c);
+                            DLIB_TEST(src_value == dest_value);
+                        }
+                        // the third part is from src3
+                        else
+                        {
+                            float src_value = tensor_read_cpu(src3, i, k-src1.k()-src2.k()+3, r, c);
+                            DLIB_TEST(src_value == dest_value);
+                        }
+                    }
+                }
+            }
+        }
+    }
+#ifdef DLIB_USE_CUDA
+    float tensor_read_gpu(const tensor& t, long i, long k, long r, long c)
+    {
+        const float* p = t.device() + t.k()*t.nr()*t.nc()*i + t.nr()*t.nc()*k + t.nc()*r + c;
+        return *p;
+    }
+    void test_copy_tensor_gpu()
+    {
+        using namespace dlib::tt;
+        print_spinner();
+        resizable_tensor dest(10, 9, 7, 15);
+        resizable_tensor src1(10, 3, 7, 15);
+        resizable_tensor src2(10, 3, 7, 15);
+        resizable_tensor src3(10, 9, 7, 15);
+        dest = matrix_cast<float>(gaussian_randm(dest.num_samples(), dest.k()*dest.nr()*dest.nc(), 1));
+        src1 = matrix_cast<float>(gaussian_randm(src1.num_samples(), src1.k()*src1.nr()*src1.nc(), 0));
+        src2 = matrix_cast<float>(gaussian_randm(src1.num_samples(), src2.k()*src2.nr()*src2.nc(), 0));
+        src3 = matrix_cast<float>(gaussian_randm(src1.num_samples(), src3.k()*src3.nr()*src3.nc(), 0));
+
+        gpu::copy_tensor(dest, 0, src1, 0, src1.k());           // full copy src1 -> dest
+        gpu::copy_tensor(dest, src1.k(), src2, 0, src2.k());    // full copy src2 -> dest, offset by src1
+        gpu::copy_tensor(dest, src1.k()+src2.k(), src3, 3, 3);  // partial copy of src3 into the rest of dest
+
+        for (long i = 0; i < dest.num_samples(); ++i)
+        {
+            for (long k = 0; k < dest.k(); ++k)
+            {
+                for (long r = 0; r < dest.nr(); ++r)
+                {
+                    for (long c = 0; c < dest.nc(); ++c)
+                    {
+                        float dest_value = tensor_read_gpu(dest, i, k, r, c);
+                        // the first part is from src1
+                        if (k < src1.k())
+                        {
+                            float src_value = tensor_read_gpu(src1, i, k, r, c);
+                            DLIB_TEST(src_value == dest_value);
+                        }
+                        // the second part is from src2
+                        else if (k < src1.k()+src2.k())
+                        {
+                            float src_value = tensor_read_gpu(src2, i, k-src1.k(), r, c);
+                            DLIB_TEST(src_value == dest_value);
+                        }
+                        // the third part is from src3
+                        else
+                        {
+                            float src_value = tensor_read_gpu(src3, i, k-src1.k()-src2.k()+3, r, c);
+                            DLIB_TEST(src_value == dest_value);
+                        }
+                    }
+                }
+            }
+        }
+    }
+#endif // DLIB_USE_CUDA

 // ----------------------------------------------------------------------------------------

     class dnn_tester : public tester

@@ -1433,6 +1618,7 @@ namespace
             compare_bn_conv_gpu_and_cpu();
             test_add();
             compare_adam();
+            test_copy_tensor_gpu();
 #endif
             test_max_pool(1,1,2,3,0,0);
             test_max_pool(3,3,1,1,0,0);

@@ -1466,6 +1652,8 @@ namespace
             test_basic_tensor_ops();
             test_layers();
             test_visit_funcions();
+            test_copy_tensor_cpu();
+            dlib::dnn_tester::test_concat();
         }
     } a;
examples/dnn_inception_ex.cpp

@@ -15,22 +15,42 @@
 #include <dlib/dnn.h>
 #include <iostream>
 #include <dlib/data_io.h>
-#include <tuple>

 using namespace std;
 using namespace dlib;

-// Here we define the inception module as described in the GoogLeNet specification. The
-// depth of each sublayer can be changed.
-template <typename SUBNET> using inception =
-    grp<std::tuple<con<8,1,1,1,1,group_input>,
-                   con<8,3,3,1,1,con<8,1,1,1,1,group_input>>,
-                   con<8,5,5,1,1,con<8,1,1,1,1,group_input>>,
-                   con<8,1,1,1,1,max_pool<3,3,1,1,group_input>>>,
-        SUBNET>;
+// An inception layer contains several different convolutions. Here we define blocks,
+// convolutions with different kernel sizes, that we will use in the inception layer block.
+template <typename SUBNET> using block_a1 = relu<con<4,1,1,1,1,SUBNET>>;
+template <typename SUBNET> using block_a2 = relu<con<4,3,3,1,1,relu<con<4,1,1,1,1,SUBNET>>>>;
+template <typename SUBNET> using block_a3 = relu<con<4,5,5,1,1,relu<con<4,1,1,1,1,SUBNET>>>>;
+template <typename SUBNET> using block_a4 = relu<con<4,1,1,1,1,max_pool<3,3,1,1,SUBNET>>>;
+
+// Here is the inception layer definition. It uses the different blocks to process its
+// input and returns their combined output.
+template <typename SUBNET> using incept_a = inception4<block_a1, block_a2, block_a3, block_a4, SUBNET>;
+
+// A network can have inception layers of different structure. Here are blocks with
+// different convolutions for a second inception layer.
+template <typename SUBNET> using block_b1 = relu<con<8,1,1,1,1,SUBNET>>;
+template <typename SUBNET> using block_b2 = relu<con<8,3,3,1,1,SUBNET>>;
+template <typename SUBNET> using block_b3 = relu<con<8,1,1,1,1,max_pool<3,3,1,1,SUBNET>>>;
+template <typename SUBNET> using incept_b = inception3<block_b1, block_b2, block_b3, SUBNET>;
+
+// And then the network type is:
+using net_type = loss_multiclass_log<
+        fc<10,
+        relu<fc<32,
+        max_pool<2,2,2,2,incept_b<
+        max_pool<2,2,2,2,incept_a<
+        input<matrix<unsigned char>>
+        >>>>>>>>;

 int main(int argc, char** argv) try
 {
     // This example is going to run on the MNIST dataset.
     if (argc != 2)
     {
         cout << "This example needs the MNIST dataset to run!" << endl;

@@ -48,25 +68,10 @@ int main(int argc, char** argv) try
     load_mnist_dataset(argv[1], training_images, training_labels, testing_images, testing_labels);

-    // Create the same network as in dnn_mnist_ex, but use an inception layer instead of
-    // a convolution in the middle.
-    using net_type = loss_multiclass_log<
-            fc<10,
-            relu<fc<84,
-            relu<fc<120,
-            max_pool<2,2,2,2,relu<inception<
-            max_pool<2,2,2,2,relu<con<6,5,5,1,1,
-            input<matrix<unsigned char>>
-            >>>>>>>>>>>>;
-    // Create a network as defined above. This network will produce 10 outputs
-    // because that's how we defined net_type. However, fc layers can have the
-    // number of outputs they produce changed at runtime.
+    // The rest of the sample is identical to dnn_mnist_ex.
+    // Create a network of the predefined type.
     net_type net;

-    // And then train it using the MNIST data. The code below uses mini-batch stochastic
-    // gradient descent with an initial learning rate of 0.01 to accomplish this.
+    // The following training process is the same as in the dnn_mnist_ex sample: mini-batch
+    // stochastic gradient descent with an initial learning rate of 0.01.
     dnn_trainer<net_type> trainer(net);

@@ -80,12 +85,12 @@ int main(int argc, char** argv) try
     // from scratch. This is because, when the program restarts, this call to
     // set_synchronization_file() will automatically reload the settings from mnist_sync if
     // the file exists.
-    trainer.set_synchronization_file("mnist_sync", std::chrono::seconds(20));
+    trainer.set_synchronization_file("inception_sync", std::chrono::seconds(20));

     // Finally, this line begins training. By default, it runs SGD with our specified
     // learning rate until the loss stops decreasing. Then it reduces the learning rate by
     // a factor of 10 and continues running until the loss stops decreasing again. It will
     // keep doing this until the learning rate has dropped below the min learning rate
     // defined above or the maximum number of epochs has been executed (defaulted to 10000).
     trainer.train(training_images, training_labels);

@@ -96,7 +101,7 @@ int main(int argc, char** argv) try
     // At this point our net object should have learned how to classify MNIST images. But
     // about that kind of transient data so that our file will be smaller. We do this by
     // "cleaning" the network before saving it.
     net.clean();
-    serialize("mnist_network.dat") << net;
+    serialize("mnist_network_inception.dat") << net;

     // Now if we later wanted to recall the network from disk we can simply say:
     // deserialize("mnist_network.dat") >> net;