Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
e6437d7d
Commit
e6437d7d
authored
9 years ago
by
Davis King
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Made the dnn_trainer check for convergence every iteration rather than only
once every few thousand iterations.
parent
0ecff0e6
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
10 deletions
+14
-10
trainer.h
dlib/dnn/trainer.h
+14
-10
No files found.
dlib/dnn/trainer.h
View file @
e6437d7d
...
@@ -396,7 +396,9 @@ namespace dlib
...
@@ -396,7 +396,9 @@ namespace dlib
{
{
double
loss
=
net
.
update
(
next_job
.
t
,
next_job
.
labels
.
begin
(),
make_sstack
(
solvers
),
step_size
);
double
loss
=
net
.
update
(
next_job
.
t
,
next_job
.
labels
.
begin
(),
make_sstack
(
solvers
),
step_size
);
rs
.
add
(
loss
);
rs
.
add
(
loss
);
rg
.
add
(
loss
);
previous_loss_values
.
push_back
(
loss
);
if
(
previous_loss_values
.
size
()
>
iter_between_step_size_adjust
)
previous_loss_values
.
pop_front
();
}
}
void
run_update
(
job_t
&
next_job
,
const
no_label_type
&
)
void
run_update
(
job_t
&
next_job
,
const
no_label_type
&
)
...
@@ -404,7 +406,9 @@ namespace dlib
...
@@ -404,7 +406,9 @@ namespace dlib
no_label_type
pick_wich_run_update
;
no_label_type
pick_wich_run_update
;
double
loss
=
net
.
update
(
next_job
.
t
,
make_sstack
(
solvers
),
step_size
);
double
loss
=
net
.
update
(
next_job
.
t
,
make_sstack
(
solvers
),
step_size
);
rs
.
add
(
loss
);
rs
.
add
(
loss
);
rg
.
add
(
loss
);
previous_loss_values
.
push_back
(
loss
);
if
(
previous_loss_values
.
size
()
>
iter_between_step_size_adjust
)
previous_loss_values
.
pop_front
();
}
}
void
thread
()
try
void
thread
()
try
...
@@ -422,13 +426,13 @@ namespace dlib
...
@@ -422,13 +426,13 @@ namespace dlib
// If we have been running for a while then check if the loss is still
// If we have been running for a while then check if the loss is still
// dropping. If it isn't then we will reduce the step size.
// dropping. If it isn't then we will reduce the step size.
if
(
rg
.
current_n
()
>
iter_between_step_size_adjust
)
if
(
previous_loss_values
.
size
()
>=
iter_between_step_size_adjust
)
{
{
if
(
rg
.
probability_gradient_greater_than
(
0
)
>
0
.
45
)
if
(
probability_gradient_greater_than
(
previous_loss_values
,
0
)
>
0
.
49
)
{
{
step_size
=
step_size_shrink
*
step_size
;
step_size
=
step_size_shrink
*
step_size
;
previous_loss_values
.
clear
();
}
}
rg
.
clear
();
}
}
}
}
}
}
...
@@ -470,13 +474,13 @@ namespace dlib
...
@@ -470,13 +474,13 @@ namespace dlib
friend
void
serialize
(
const
dnn_trainer
&
item
,
std
::
ostream
&
out
)
friend
void
serialize
(
const
dnn_trainer
&
item
,
std
::
ostream
&
out
)
{
{
item
.
wait_for_thread_to_pause
();
item
.
wait_for_thread_to_pause
();
int
version
=
3
;
int
version
=
4
;
serialize
(
version
,
out
);
serialize
(
version
,
out
);
size_t
nl
=
dnn_trainer
::
num_layers
;
size_t
nl
=
dnn_trainer
::
num_layers
;
serialize
(
nl
,
out
);
serialize
(
nl
,
out
);
serialize
(
item
.
rs
,
out
);
serialize
(
item
.
rs
,
out
);
serialize
(
item
.
rg
,
out
);
serialize
(
item
.
previous_loss_values
,
out
);
serialize
(
item
.
max_num_epochs
,
out
);
serialize
(
item
.
max_num_epochs
,
out
);
serialize
(
item
.
mini_batch_size
,
out
);
serialize
(
item
.
mini_batch_size
,
out
);
serialize
(
item
.
verbose
,
out
);
serialize
(
item
.
verbose
,
out
);
...
@@ -495,7 +499,7 @@ namespace dlib
...
@@ -495,7 +499,7 @@ namespace dlib
item
.
wait_for_thread_to_pause
();
item
.
wait_for_thread_to_pause
();
int
version
=
0
;
int
version
=
0
;
deserialize
(
version
,
in
);
deserialize
(
version
,
in
);
if
(
version
!=
3
)
if
(
version
!=
4
)
throw
serialization_error
(
"Unexpected version found while deserializing dlib::dnn_trainer."
);
throw
serialization_error
(
"Unexpected version found while deserializing dlib::dnn_trainer."
);
size_t
num_layers
=
0
;
size_t
num_layers
=
0
;
...
@@ -511,7 +515,7 @@ namespace dlib
...
@@ -511,7 +515,7 @@ namespace dlib
double
dtemp
;
long
ltemp
;
double
dtemp
;
long
ltemp
;
deserialize
(
item
.
rs
,
in
);
deserialize
(
item
.
rs
,
in
);
deserialize
(
item
.
rg
,
in
);
deserialize
(
item
.
previous_loss_values
,
in
);
deserialize
(
item
.
max_num_epochs
,
in
);
deserialize
(
item
.
max_num_epochs
,
in
);
deserialize
(
item
.
mini_batch_size
,
in
);
deserialize
(
item
.
mini_batch_size
,
in
);
deserialize
(
item
.
verbose
,
in
);
deserialize
(
item
.
verbose
,
in
);
...
@@ -562,7 +566,7 @@ namespace dlib
...
@@ -562,7 +566,7 @@ namespace dlib
dlib
::
pipe
<
job_t
>
job_pipe
;
dlib
::
pipe
<
job_t
>
job_pipe
;
running_stats
<
double
>
rs
;
running_stats
<
double
>
rs
;
running_gradient
rg
;
std
::
deque
<
double
>
previous_loss_values
;
unsigned
long
max_num_epochs
;
unsigned
long
max_num_epochs
;
size_t
mini_batch_size
;
size_t
mini_batch_size
;
bool
verbose
;
bool
verbose
;
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment