Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
6c36592c
Commit
6c36592c
authored
Oct 15, 2015
by
Davis King
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added serialization support to everything.
parent
e679d66a
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
437 additions
and
69 deletions
+437
-69
core.h
dlib/dnn/core.h
+147
-1
core_abstract.h
dlib/dnn/core_abstract.h
+38
-0
input.h
dlib/dnn/input.h
+28
-0
input_abstract.h
dlib/dnn/input_abstract.h
+14
-0
layers.h
dlib/dnn/layers.h
+32
-67
layers_abstract.h
dlib/dnn/layers_abstract.h
+16
-0
loss.h
dlib/dnn/loss.h
+26
-0
loss_abstract.h
dlib/dnn/loss_abstract.h
+12
-0
solvers.h
dlib/dnn/solvers.h
+21
-0
solvers_abstract.h
dlib/dnn/solvers_abstract.h
+12
-0
tensor.h
dlib/dnn/tensor.h
+56
-0
trainer.h
dlib/dnn/trainer.h
+27
-1
trainer_abstract.h
dlib/dnn/trainer_abstract.h
+8
-0
No files found.
dlib/dnn/core.h
View file @
6c36592c
...
...
@@ -67,6 +67,18 @@ namespace dlib
const
sstack
<
T
,
N
-
1
>&
pop
()
const
{
return
data
;
}
sstack
<
T
,
N
-
1
>&
pop
()
{
return
data
;
}
friend
void
serialize
(
const
sstack
&
item
,
std
::
ostream
&
out
)
{
serialize
(
item
.
top
(),
out
);
serialize
(
item
.
pop
(),
out
);
}
friend
void
deserialize
(
sstack
&
item
,
std
::
istream
&
in
)
{
deserialize
(
item
.
top
(),
in
);
deserialize
(
item
.
pop
(),
in
);
}
private
:
T
item
;
sstack
<
T
,
N
-
1
>
data
;
...
...
@@ -83,6 +95,17 @@ namespace dlib
T
&
top
()
{
return
item
;
}
size_t
size
()
const
{
return
1
;
}
friend
void
serialize
(
const
sstack
&
item
,
std
::
ostream
&
out
)
{
serialize
(
item
.
top
(),
out
);
}
friend
void
deserialize
(
sstack
&
item
,
std
::
istream
&
in
)
{
deserialize
(
item
.
top
(),
in
);
}
private
:
T
item
;
};
...
...
@@ -294,6 +317,32 @@ namespace dlib
subnetwork
.
clean
();
}
friend
void
serialize
(
const
add_layer
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
version
,
out
);
serialize
(
item
.
subnetwork
,
out
);
serialize
(
item
.
details
,
out
);
serialize
(
item
.
this_layer_setup_called
,
out
);
serialize
(
item
.
gradient_input_is_stale
,
out
);
serialize
(
item
.
x_grad
,
out
);
serialize
(
item
.
cached_output
,
out
);
}
friend
void
deserialize
(
add_layer
&
item
,
std
::
istream
&
in
)
{
int
version
=
0
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::add_layer."
);
deserialize
(
item
.
subnetwork
,
in
);
deserialize
(
item
.
details
,
in
);
deserialize
(
item
.
this_layer_setup_called
,
in
);
deserialize
(
item
.
gradient_input_is_stale
,
in
);
deserialize
(
item
.
x_grad
,
in
);
deserialize
(
item
.
cached_output
,
in
);
}
private
:
...
...
@@ -468,6 +517,32 @@ namespace dlib
gradient_input_is_stale
=
true
;
}
friend
void
serialize
(
const
add_layer
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
version
,
out
);
serialize
(
item
.
input_layer
,
out
);
serialize
(
item
.
details
,
out
);
serialize
(
item
.
this_layer_setup_called
,
out
);
serialize
(
item
.
gradient_input_is_stale
,
out
);
serialize
(
item
.
x_grad
,
out
);
serialize
(
item
.
cached_output
,
out
);
}
friend
void
deserialize
(
add_layer
&
item
,
std
::
istream
&
in
)
{
int
version
=
0
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::add_layer."
);
deserialize
(
item
.
input_layer
,
in
);
deserialize
(
item
.
details
,
in
);
deserialize
(
item
.
this_layer_setup_called
,
in
);
deserialize
(
item
.
gradient_input_is_stale
,
in
);
deserialize
(
item
.
x_grad
,
in
);
deserialize
(
item
.
cached_output
,
in
);
}
private
:
class
subnet_wrapper
...
...
@@ -601,6 +676,22 @@ namespace dlib
subnetwork
.
clean
();
}
friend
void
serialize
(
const
add_tag_layer
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
version
,
out
);
serialize
(
item
.
subnetwork
,
out
);
}
friend
void
deserialize
(
add_tag_layer
&
item
,
std
::
istream
&
in
)
{
int
version
=
0
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::add_tag_layer."
);
deserialize
(
item
.
subnetwork
,
in
);
}
private
:
subnet_type
subnetwork
;
...
...
@@ -702,6 +793,26 @@ namespace dlib
cached_output
.
clear
();
}
friend
void
serialize
(
const
add_tag_layer
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
version
,
out
);
serialize
(
item
.
input_layer
,
out
);
serialize
(
item
.
cached_output
,
out
);
serialize
(
item
.
grad_final_ignored
,
out
);
}
friend
void
deserialize
(
add_tag_layer
&
item
,
std
::
istream
&
in
)
{
int
version
=
0
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::add_tag_layer."
);
deserialize
(
item
.
input_layer
,
in
);
deserialize
(
item
.
cached_output
,
in
);
deserialize
(
item
.
grad_final_ignored
,
in
);
}
private
:
subnet_type
input_layer
;
...
...
@@ -759,7 +870,8 @@ namespace dlib
const
static
unsigned
int
sample_expansion_factor
=
subnet_type
::
sample_expansion_factor
;
typedef
typename
get_loss_layer_label_type
<
LOSS_DETAILS
>::
type
label_type
;
static_assert
(
is_nonloss_layer_type
<
SUBNET
>::
value
,
"SUBNET must be of type add_layer, add_skip_layer, or add_tag_layer."
);
static_assert
(
is_nonloss_layer_type
<
SUBNET
>::
value
,
"SUBNET must be of type add_layer, add_skip_layer, or add_tag_layer."
);
static_assert
(
sample_expansion_factor
==
LOSS_DETAILS
::
sample_expansion_factor
,
"The loss layer and input layer must agree on the sample_expansion_factor."
);
...
...
@@ -947,6 +1059,24 @@ namespace dlib
subnetwork
.
clear
();
}
friend
void
serialize
(
const
add_loss_layer
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
version
,
out
);
serialize
(
item
.
loss
,
out
);
serialize
(
item
.
subnetwork
,
out
);
}
friend
void
deserialize
(
add_loss_layer
&
item
,
std
::
istream
&
in
)
{
int
version
=
0
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::add_loss_layer."
);
deserialize
(
item
.
loss
,
in
);
deserialize
(
item
.
subnetwork
,
in
);
}
private
:
loss_details_type
loss
;
...
...
@@ -1150,6 +1280,22 @@ namespace dlib
subnetwork
.
clean
();
}
friend
void
serialize
(
const
add_skip_layer
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
version
,
out
);
serialize
(
item
.
subnetwork
,
out
);
}
friend
void
deserialize
(
add_skip_layer
&
item
,
std
::
istream
&
in
)
{
int
version
=
0
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::add_skip_layer."
);
deserialize
(
item
.
subnetwork
,
in
);
}
private
:
subnet_type
subnetwork
;
...
...
dlib/dnn/core_abstract.h
View file @
6c36592c
...
...
@@ -119,6 +119,12 @@ namespace dlib
!*/
};
void
serialize
(
const
sstack
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
sstack
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
template
<
...
...
@@ -378,6 +384,14 @@ namespace dlib
};
template
<
typename
T
,
typename
U
>
,
void
serialize
(
const
add_layer
<
T
,
U
>&
item
,
std
::
ostream
&
out
);
template
<
typename
T
,
typename
U
>
,
void
deserialize
(
add_layer
<
T
,
U
>&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
...
...
@@ -769,6 +783,14 @@ namespace dlib
!*/
};
template
<
typename
T
,
typename
U
>
,
void
serialize
(
const
add_loss_layer
<
T
,
U
>&
item
,
std
::
ostream
&
out
);
template
<
typename
T
,
typename
U
>
,
void
deserialize
(
add_loss_layer
<
T
,
U
>&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
...
...
@@ -799,6 +821,14 @@ namespace dlib
!*/
};
template
<
unsigned
long
ID
,
typename
U
>
,
void
serialize
(
const
add_tag_layer
<
ID
,
U
>&
item
,
std
::
ostream
&
out
);
template
<
unsigned
long
ID
,
typename
U
>
,
void
deserialize
(
add_tag_layer
<
ID
,
U
>&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
template
<
typename
SUBNET
>
using
tag1
=
add_tag_layer
<
1
,
SUBNET
>
;
template
<
typename
SUBNET
>
using
tag2
=
add_tag_layer
<
2
,
SUBNET
>
;
template
<
typename
SUBNET
>
using
tag3
=
add_tag_layer
<
3
,
SUBNET
>
;
...
...
@@ -834,6 +864,14 @@ namespace dlib
!*/
};
template
<
template
<
typename
>
class
T
,
typename
U
>
void
serialize
(
const
add_skip_layer
<
T
,
U
>&
item
,
std
::
ostream
&
out
);
template
<
template
<
typename
>
class
T
,
typename
U
>
void
deserialize
(
add_skip_layer
<
T
,
U
>&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
template
<
typename
SUBNET
>
using
skip1
=
add_skip_layer
<
tag1
,
SUBNET
>
;
template
<
typename
SUBNET
>
using
skip2
=
add_skip_layer
<
tag2
,
SUBNET
>
;
template
<
typename
SUBNET
>
using
skip3
=
add_skip_layer
<
tag3
,
SUBNET
>
;
...
...
dlib/dnn/input.h
View file @
6c36592c
...
...
@@ -73,6 +73,20 @@ namespace dlib
}
}
friend
void
serialize
(
const
input
&
item
,
std
::
ostream
&
out
)
{
serialize
(
"input<matrix>"
,
out
);
}
friend
void
deserialize
(
input
&
item
,
std
::
istream
&
in
)
{
std
::
string
version
;
deserialize
(
version
,
in
);
if
(
version
!=
"input<matrix>"
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::input."
);
}
};
// ----------------------------------------------------------------------------------------
...
...
@@ -126,6 +140,20 @@ namespace dlib
}
}
friend
void
serialize
(
const
input
&
item
,
std
::
ostream
&
out
)
{
serialize
(
"input<array2d>"
,
out
);
}
friend
void
deserialize
(
input
&
item
,
std
::
istream
&
in
)
{
std
::
string
version
;
deserialize
(
version
,
in
);
if
(
version
!=
"input<array2d>"
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::input."
);
}
};
// ----------------------------------------------------------------------------------------
...
...
dlib/dnn/input_abstract.h
View file @
6c36592c
...
...
@@ -86,6 +86,12 @@ namespace dlib
!*/
};
void
serialize
(
const
EXAMPLE_INPUT_LAYER
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
EXAMPLE_INPUT_LAYER
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
template
<
...
...
@@ -132,6 +138,14 @@ namespace dlib
!*/
};
template
<
typename
T
>
void
serialize
(
const
input
<
T
>&
item
,
std
::
ostream
&
out
);
template
<
typename
T
>
void
deserialize
(
input
<
T
>&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
}
...
...
dlib/dnn/layers.h
View file @
6c36592c
...
...
@@ -59,13 +59,12 @@ namespace dlib
public
:
fc_
()
:
num_outputs
(
1
)
{
rnd
.
set_seed
(
"fc_"
+
cast_to_string
(
num_outputs
));
}
explicit
fc_
(
unsigned
long
num_outputs_
)
explicit
fc_
(
unsigned
long
num_outputs_
)
:
num_outputs
(
num_outputs_
)
{
num_outputs
=
num_outputs_
;
rnd
.
set_seed
(
"fc_"
+
cast_to_string
(
num_outputs
));
}
unsigned
long
get_num_outputs
(
...
...
@@ -77,6 +76,7 @@ namespace dlib
num_inputs
=
sub
.
get_output
().
nr
()
*
sub
.
get_output
().
nc
()
*
sub
.
get_output
().
k
();
params
.
set_size
(
num_inputs
,
num_outputs
);
dlib
::
rand
rnd
(
"fc_"
+
cast_to_string
(
num_outputs
));
randomize_parameters
(
params
,
num_inputs
+
num_outputs
,
rnd
);
}
...
...
@@ -101,12 +101,30 @@ namespace dlib
const
tensor
&
get_layer_params
()
const
{
return
params
;
}
tensor
&
get_layer_params
()
{
return
params
;
}
friend
void
serialize
(
const
fc_
&
item
,
std
::
ostream
&
out
)
{
serialize
(
"fc_"
,
out
);
serialize
(
item
.
num_outputs
,
out
);
serialize
(
item
.
num_inputs
,
out
);
serialize
(
item
.
params
,
out
);
}
friend
void
deserialize
(
fc_
&
item
,
std
::
istream
&
in
)
{
std
::
string
version
;
deserialize
(
version
,
in
);
if
(
version
!=
"fc_"
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::fc_."
);
deserialize
(
item
.
num_outputs
,
in
);
deserialize
(
item
.
num_inputs
,
in
);
deserialize
(
item
.
params
,
in
);
}
private
:
unsigned
long
num_outputs
;
unsigned
long
num_inputs
;
resizable_tensor
params
;
dlib
::
rand
rnd
;
};
...
...
@@ -151,81 +169,28 @@ namespace dlib
const
tensor
&
get_layer_params
()
const
{
return
params
;
}
tensor
&
get_layer_params
()
{
return
params
;
}
private
:
resizable_tensor
params
;
};
template
<
typename
SUBNET
>
using
relu
=
add_layer
<
relu_
,
SUBNET
>
;
// ----------------------------------------------------------------------------------------
class
multiply_
{
public
:
multiply_
()
{
}
template
<
typename
SUBNET
>
void
setup
(
const
SUBNET
&
sub
)
{
num_inputs
=
sub
.
get_output
().
nr
()
*
sub
.
get_output
().
nc
()
*
sub
.
get_output
().
k
();
params
.
set_size
(
1
,
num_inputs
);
std
::
cout
<<
"multiply_::setup() "
<<
params
.
size
()
<<
std
::
endl
;
const
int
num_outputs
=
num_inputs
;
randomize_parameters
(
params
,
num_inputs
+
num_outputs
,
rnd
);
}
template
<
typename
SUBNET
>
void
forward
(
const
SUBNET
&
sub
,
resizable_tensor
&
output
)
{
DLIB_CASSERT
(
sub
.
get_output
().
nr
()
*
sub
.
get_output
().
nc
()
*
sub
.
get_output
().
k
()
==
params
.
size
(),
""
);
DLIB_CASSERT
(
sub
.
get_output
().
nr
()
*
sub
.
get_output
().
nc
()
*
sub
.
get_output
().
k
()
==
num_inputs
,
""
);
output
.
copy_size
(
sub
.
get_output
());
auto
indata
=
sub
.
get_output
().
host
();
auto
outdata
=
output
.
host
();
auto
paramdata
=
params
.
host
();
for
(
int
i
=
0
;
i
<
sub
.
get_output
().
num_samples
();
++
i
)
friend
void
serialize
(
const
relu_
&
item
,
std
::
ostream
&
out
)
{
for
(
int
j
=
0
;
j
<
num_inputs
;
++
j
)
{
*
outdata
++
=
*
indata
++
*
paramdata
[
j
];
}
}
serialize
(
"relu_"
,
out
);
}
template
<
typename
SUBNET
>
void
backward
(
const
tensor
&
gradient_input
,
SUBNET
&
sub
,
tensor
&
params_grad
)
{
params_grad
+=
sum_rows
(
pointwise_multiply
(
mat
(
sub
.
get_output
()),
mat
(
gradient_input
)));
for
(
long
i
=
0
;
i
<
gradient_input
.
num_samples
();
++
i
)
friend
void
deserialize
(
relu_
&
item
,
std
::
istream
&
in
)
{
sub
.
get_gradient_input
().
add_to_sample
(
i
,
pointwise_multiply
(
rowm
(
mat
(
gradient_input
),
i
),
mat
(
params
)));
}
std
::
string
version
;
deserialize
(
version
,
in
);
if
(
version
!=
"relu_"
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::relu_."
);
}
const
tensor
&
get_layer_params
()
const
{
return
params
;
}
tensor
&
get_layer_params
()
{
return
params
;
}
private
:
int
num_inputs
;
resizable_tensor
params
;
dlib
::
rand
rnd
;
};
template
<
typename
SUBNET
>
using
multiply
=
add_layer
<
multiply
_
,
SUBNET
>
;
using
relu
=
add_layer
<
relu
_
,
SUBNET
>
;
// ----------------------------------------------------------------------------------------
...
...
dlib/dnn/layers_abstract.h
View file @
6c36592c
...
...
@@ -218,6 +218,12 @@ namespace dlib
};
void
serialize
(
const
EXAMPLE_LAYER_
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
EXAMPLE_LAYER_
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// For each layer you define, always define an add_layer template so that layers can be
// easily composed. Moreover, the convention is that the layer class ends with an _
// while the add_layer template has the same name but without the trailing _.
...
...
@@ -274,6 +280,11 @@ namespace dlib
!*/
};
void
serialize
(
const
fc_
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
fc_
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
template
<
typename
SUBNET
>
using
fc
=
add_layer
<
fc_
,
SUBNET
>
;
...
...
@@ -306,6 +317,11 @@ namespace dlib
!*/
};
void
serialize
(
const
relu_
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
relu_
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
template
<
typename
SUBNET
>
using
relu
=
add_layer
<
relu_
,
SUBNET
>
;
...
...
dlib/dnn/loss.h
View file @
6c36592c
...
...
@@ -81,6 +81,19 @@ namespace dlib
return
loss
;
}
friend
void
serialize
(
const
loss_binary_hinge_
&
item
,
std
::
ostream
&
out
)
{
serialize
(
"loss_binary_hinge_"
,
out
);
}
friend
void
deserialize
(
loss_binary_hinge_
&
item
,
std
::
istream
&
in
)
{
std
::
string
version
;
deserialize
(
version
,
in
);
if
(
version
!=
"loss_binary_hinge_"
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::loss_binary_hinge_."
);
}
};
template
<
typename
SUBNET
>
...
...
@@ -105,6 +118,19 @@ namespace dlib
return
0
;
}
friend
void
serialize
(
const
loss_no_label_
&
item
,
std
::
ostream
&
out
)
{
serialize
(
"loss_no_label_"
,
out
);
}
friend
void
deserialize
(
loss_no_label_
&
item
,
std
::
istream
&
in
)
{
std
::
string
version
;
deserialize
(
version
,
in
);
if
(
version
!=
"loss_no_label_"
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::loss_no_label_."
);
}
};
template
<
typename
SUBNET
>
...
...
dlib/dnn/loss_abstract.h
View file @
6c36592c
...
...
@@ -118,6 +118,12 @@ namespace dlib
!*/
};
void
serialize
(
const
EXAMPLE_LOSS_LAYER_
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
EXAMPLE_LOSS_LAYER_
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// For each loss layer you define, always define an add_loss_layer template so that
// layers can be easily composed. Moreover, the convention is that the layer class
// ends with an _ while the add_loss_layer template has the same name but without the
...
...
@@ -179,6 +185,12 @@ namespace dlib
};
void
serialize
(
const
loss_binary_hinge_
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
loss_binary_hinge_
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
template
<
typename
SUBNET
>
using
loss_binary_hinge
=
add_loss_layer
<
loss_binary_hinge_
,
SUBNET
>
;
...
...
dlib/dnn/solvers.h
View file @
6c36592c
...
...
@@ -48,6 +48,27 @@ namespace dlib
l
.
get_layer_params
()
+=
v
;
}
friend
void
serialize
(
const
sgd
&
item
,
std
::
ostream
&
out
)
{
serialize
(
"sgd"
,
out
);
serialize
(
item
.
v
,
out
);
serialize
(
item
.
weight_decay
,
out
);
serialize
(
item
.
learning_rate
,
out
);
serialize
(
item
.
momentum
,
out
);
}
friend
void
deserialize
(
sgd
&
item
,
std
::
istream
&
in
)
{
std
::
string
version
;
deserialize
(
version
,
in
);
if
(
version
!=
"sgd"
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::sgd."
);
deserialize
(
item
.
v
,
in
);
deserialize
(
item
.
weight_decay
,
in
);
deserialize
(
item
.
learning_rate
,
in
);
deserialize
(
item
.
momentum
,
in
);
}
private
:
matrix
<
float
>
v
;
float
weight_decay
;
...
...
dlib/dnn/solvers_abstract.h
View file @
6c36592c
...
...
@@ -52,6 +52,12 @@ namespace dlib
!*/
};
void
serialize
(
const
EXAMPLE_SOLVER
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
EXAMPLE_SOLVER
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
...
...
@@ -92,6 +98,12 @@ namespace dlib
float
get_momentum
()
const
;
};
void
serialize
(
const
sgd
&
item
,
std
::
ostream
&
out
);
void
deserialize
(
sgd
&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
}
...
...
dlib/dnn/tensor.h
View file @
6c36592c
...
...
@@ -112,6 +112,7 @@ namespace dlib
size_t
size
()
const
{
return
data_size
;
}
private
:
void
copy_to_device
()
const
...
...
@@ -144,6 +145,30 @@ namespace dlib
std
::
unique_ptr
<
float
[]
>
data_device
;
};
inline
void
serialize
(
const
gpu_data
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
item
.
size
(),
out
);
auto
data
=
item
.
host
();
for
(
size_t
i
=
0
;
i
<
item
.
size
();
++
i
)
serialize
(
data
[
i
],
out
);
}
inline
void
deserialize
(
gpu_data
&
item
,
std
::
istream
&
in
)
{
int
version
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::gpu_data."
);
size_t
s
;
deserialize
(
s
,
in
);
item
.
set_size
(
s
);
auto
data
=
item
.
host
();
for
(
size_t
i
=
0
;
i
<
item
.
size
();
++
i
)
deserialize
(
data
[
i
],
in
);
}
// ----------------------------------------------------------------------------------------
class
tensor
...
...
@@ -466,6 +491,37 @@ namespace dlib
}
};
inline
void
serialize
(
const
tensor
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
version
,
out
);
serialize
(
item
.
num_samples
(),
out
);
serialize
(
item
.
nr
(),
out
);
serialize
(
item
.
nc
(),
out
);
serialize
(
item
.
k
(),
out
);
auto
data
=
item
.
host
();
for
(
size_t
i
=
0
;
i
<
item
.
size
();
++
i
)
serialize
(
data
[
i
],
out
);
}
inline
void
deserialize
(
resizable_tensor
&
item
,
std
::
istream
&
in
)
{
int
version
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::resizable_tensor."
);
long
num_samples
=
0
,
nr
=
0
,
nc
=
0
,
k
=
0
;
deserialize
(
num_samples
,
in
);
deserialize
(
nr
,
in
);
deserialize
(
nc
,
in
);
deserialize
(
k
,
in
);
item
.
set_size
(
num_samples
,
nr
,
nc
,
k
);
auto
data
=
item
.
host
();
for
(
size_t
i
=
0
;
i
<
item
.
size
();
++
i
)
deserialize
(
data
[
i
],
in
);
}
// ----------------------------------------------------------------------------------------
inline
double
dot
(
...
...
dlib/dnn/trainer.h
View file @
6c36592c
...
...
@@ -9,6 +9,7 @@
#include "../statistics.h"
#include "../console_progress_indicator.h"
#include <chrono>
#include "../serialize.h"
namespace
dlib
{
...
...
@@ -281,8 +282,34 @@ namespace dlib
return
net
;
}
friend
void
serialize
(
const
dnn_trainer
&
item
,
std
::
ostream
&
out
)
{
int
version
=
1
;
serialize
(
version
,
out
);
serialize
(
item
.
num_epochs
,
out
);
serialize
(
item
.
mini_batch_size
,
out
);
serialize
(
item
.
verbose
,
out
);
serialize
(
item
.
net
,
out
);
serialize
(
item
.
solvers
,
out
);
}
friend
void
deserialize
(
dnn_trainer
&
item
,
std
::
istream
&
in
)
{
int
version
=
0
;
deserialize
(
version
,
in
);
if
(
version
!=
1
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::dnn_trainer."
);
deserialize
(
item
.
num_epochs
,
in
);
deserialize
(
item
.
mini_batch_size
,
in
);
deserialize
(
item
.
verbose
,
in
);
deserialize
(
item
.
net
,
in
);
deserialize
(
item
.
solvers
,
in
);
}
private
:
const
static
long
string_pad
=
10
;
void
init
()
{
num_epochs
=
300
;
...
...
@@ -293,7 +320,6 @@ namespace dlib
unsigned
long
num_epochs
;
unsigned
long
mini_batch_size
;
bool
verbose
;
const
static
long
string_pad
=
10
;
net_type
net
;
sstack
<
solver_type
,
net_type
::
num_layers
>
solvers
;
...
...
dlib/dnn/trainer_abstract.h
View file @
6c36592c
...
...
@@ -222,6 +222,14 @@ namespace dlib
};
template
<
typename
T
,
typename
U
>
void
serialize
(
const
dnn_trainer
<
T
,
U
>&
item
,
std
::
ostream
&
out
);
template
<
typename
T
,
typename
U
>
void
deserialize
(
dnn_trainer
<
T
,
U
>&
item
,
std
::
istream
&
in
);
/*!
provides serialization support
!*/
// ----------------------------------------------------------------------------------------
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment