commit ca11d108
Author: Davis King
Date:   Apr 19, 2016
Parent: b9fd9564

    Added multi-gpu support to the dnn_trainer
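The headline change: dnn_trainer can now drive several CUDA devices at once. The constructor gains a cuda_extra_devices argument (documented in trainer_abstract.h below), each mini-batch is split across the devices, and the per-device results are combined before the solvers run. A minimal usage sketch, assuming a CUDA build of dlib and at least three GPUs; the network definition and the data loading are placeholders, not part of this commit:

    #include <dlib/dnn.h>
    #include <vector>
    using namespace dlib;

    // Placeholder network type; any dlib DNN works the same way here.
    using net_type = loss_multiclass_log<fc<10, relu<fc<84, input<matrix<unsigned char>>>>>>;

    int main()
    {
        std::vector<matrix<unsigned char>> images;  // assumed filled in elsewhere
        std::vector<unsigned long>         labels;  // assumed filled in elsewhere

        net_type net;
        // Train on the currently selected CUDA device plus devices 1 and 2.
        dnn_trainer<net_type> trainer(net, sgd(), {1, 2});
        trainer.set_mini_batch_size(128);
        trainer.train(images, labels);
    }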
Showing 2 changed files with 285 additions and 96 deletions:

    dlib/dnn/trainer.h            +268  -70
    dlib/dnn/trainer_abstract.h    +17  -26
dlib/dnn/trainer.h
@@ -18,12 +18,38 @@
 #include "../statistics/running_gradient.h"
 #include <atomic>
 #include <cstdio>
+#include <set>
+#include <future>
 
 namespace dlib
 {
 
 // ----------------------------------------------------------------------------------------
 
+    namespace impl
+    {
+        template <typename label_type>
+        struct dnn_job_t
+        {
+            dnn_job_t() = default;
+            dnn_job_t(const dnn_job_t&) = delete;
+            dnn_job_t& operator=(const dnn_job_t&) = delete;
+
+            std::vector<std::vector<label_type>> labels;
+            std::vector<resizable_tensor> t;
+            std::vector<int> have_data;  // have_data[i] is true if there is data in labels[i] and t[i].
+        };
+
+        template <typename label_type>
+        void swap(dnn_job_t<label_type>& a, dnn_job_t<label_type>& b)
+        {
+            a.labels.swap(b.labels);
+            a.t.swap(b.t);
+            a.have_data.swap(b.have_data);
+        }
+    }
+
     template <
         typename net_type,
         typename solver_type = sgd
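For orientation: the old trainer used a single-tensor job (the job_t struct removed further down). The new impl::dnn_job_t is indexed by device: t[i] and labels[i] hold device i's slice of the mini-batch, and have_data[i] records whether that slice is non-empty.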
@@ -38,20 +64,59 @@ namespace dlib
         typedef typename net_type::label_type label_type;
         typedef typename net_type::input_type input_type;
 
         const static size_t num_computational_layers = net_type::num_computational_layers;
+        const static size_t num_layers = net_type::num_layers;
+    private:
+        typedef impl::dnn_job_t<label_type> job_t;
+    public:
 
         dnn_trainer() = delete;
         dnn_trainer(const dnn_trainer&) = delete;
         dnn_trainer& operator=(const dnn_trainer&) = delete;
 
-        explicit dnn_trainer(net_type& net_) : job_pipe(0), net(net_), solvers(num_computational_layers)
+        explicit dnn_trainer(net_type& net_) : job_pipe(0), net(net_)
         {
+            solver_type default_solver;
+            devices.push_back(std::make_shared<device_data>(dlib::cuda::get_device(), net, default_solver));
+
             init();
         }
 
         dnn_trainer(
             net_type& net_,
             const solver_type& solver_
-        ) : job_pipe(0), net(net_), solvers(num_computational_layers, solver_)
+        ) : job_pipe(0), net(net_)
         {
+            devices.push_back(std::make_shared<device_data>(dlib::cuda::get_device(), net, solver_));
+
             init();
         }
 
+        dnn_trainer(
+            net_type& net_,
+            const solver_type& solver_,
+            const std::vector<int>& cuda_extra_devices
+        ) : job_pipe(0), net(net_)
+        {
+            devices.push_back(std::make_shared<device_data>(dlib::cuda::get_device(), net, solver_));
+
+            const int total_devices = dlib::cuda::get_num_devices();
+            // Make device contexts for the extra device ids but be careful to avoid any
+            // duplicate ids.
+            std::set<int> temp(cuda_extra_devices.begin(), cuda_extra_devices.end());
+            temp.erase(devices[0]->device_id);
+            for (auto id : temp)
+            {
+                DLIB_CASSERT(0 <= id && id < total_devices, "Invalid CUDA device id given to dnn_trainer.");
+                // Switch to this device so that any tensor objects that get allocated when
+                // we create the device context happen on this device.
+                dlib::cuda::set_device(id);
+                devices.push_back(std::make_shared<device_data>(id, net, solver_, clone_net()));
+            }
+            // Set the current device back to what it was before this constructor was
+            // called.
+            dlib::cuda::set_device(devices[0]->device_id);
+
+            init();
+        }
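Note the de-duplication above: the extra ids go through a std::set and the primary device's id is erased, so repeated or redundant entries are harmless. A hypothetical illustration, reusing the net and net_type placeholders from the sketch at the top, and assuming device 0 is current on the calling thread:

    // Both trainers end up with exactly two device contexts, 0 and 1:
    dnn_trainer<net_type> t1(net, sgd(), {1});
    dnn_trainer<net_type> t2(net, sgd(), {1, 1, 0});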
@@ -70,13 +135,6 @@ namespace dlib
             return net;
         }
 
-        void set_solver (
-            const solver_type& solver_
-        )
-        {
-            wait_for_thread_to_pause();
-            solvers = std::vector<solver_type>(num_computational_layers, solver_);
-        }
-
         unsigned long get_mini_batch_size (
         ) const { return mini_batch_size; }
@@ -117,22 +175,16 @@ namespace dlib
         ) const
         {
             wait_for_thread_to_pause();
-            return solvers;
+            return devices[0]->solvers;
         }
 
-        std::vector<solver_type>& get_solvers (
-        )
-        {
-            wait_for_thread_to_pause();
-            return solvers;
-        }
-
         void train_one_step (
             const std::vector<input_type>& data,
             const std::vector<label_type>& labels
         )
         {
             DLIB_CASSERT(data.size() == labels.size() && data.size() > 0,"");
             if (verbose)
             {
                 using namespace std::chrono;
@@ -149,9 +201,8 @@ namespace dlib
                 }
             }
             sync_to_disk();
-            job.labels = labels;
-            net.to_tensor(data.begin(), data.end(), job.t);
-            job_pipe.enqueue(job);
+            send_job(data.begin(), data.end(), labels.begin());
             ++train_one_step_calls;
         }
@@ -159,6 +210,7 @@ namespace dlib
             const std::vector<input_type>& data
         )
         {
+            DLIB_CASSERT(data.size() > 0,"");
             if (verbose)
             {
                 using namespace std::chrono;
@@ -175,8 +227,7 @@ namespace dlib
                 }
             }
             sync_to_disk();
-            net.to_tensor(data.begin(), data.end(), job.t);
-            job_pipe.enqueue(job);
+            send_job(data.begin(), data.end());
             ++train_one_step_calls;
         }
@@ -216,12 +267,9 @@ namespace dlib
                 }
                 sync_to_disk();
-                net.to_tensor(data.begin()+epoch_pos,
-                              data.begin()+std::min(epoch_pos+mini_batch_size, data.size()), job.t);
-                job.labels.assign(labels.begin()+epoch_pos,
-                                  labels.begin()+std::min(epoch_pos+mini_batch_size, data.size()));
-                job_pipe.enqueue(job);
+                send_job(data.begin()+epoch_pos,
+                         data.begin()+std::min(epoch_pos+mini_batch_size, data.size()),
+                         labels.begin()+epoch_pos);
                 updated_the_network = true;
             }
             epoch_pos = 0;
@@ -281,10 +329,8 @@ namespace dlib
                 }
                 sync_to_disk();
-                net.to_tensor(data.begin()+epoch_pos,
-                              data.begin()+std::min(epoch_pos+mini_batch_size, data.size()), job.t);
-                job_pipe.enqueue(job);
+                send_job(data.begin()+epoch_pos,
+                         data.begin()+std::min(epoch_pos+mini_batch_size, data.size()));
                 updated_the_network = true;
             }
             epoch_pos = 0;
@@ -393,11 +439,6 @@ namespace dlib
         }
 
     private:
 
-        struct job_t
-        {
-            std::vector<label_type> labels;
-            resizable_tensor t;
-        };
-
         void record_loss(double loss)
         {
@@ -416,34 +457,98 @@ namespace dlib
         }
 
         template <typename T>
-        void run_update(job_t& next_job, const T&)
+        double compute_parameter_gradients(size_t device, job_t& next_job, const T&)
         {
-            double loss = net.compute_parameter_gradients(next_job.t, next_job.labels.begin());
-            net.update_parameters(make_sstack(solvers), step_size);
-            record_loss(loss);
+            if (next_job.have_data[device])
+            {
+                auto&& dev = *devices[device];
+                dlib::cuda::set_device(dev.device_id);
+                return dev.net.compute_parameter_gradients(next_job.t[device], next_job.labels[device].begin());
+            }
+            else
+            {
+                return 0;
+            }
         }
 
-        void run_update(job_t& next_job, const no_label_type&)
+        double compute_parameter_gradients(size_t device, job_t& next_job, const no_label_type&)
         {
-            no_label_type pick_which_run_update;
-            double loss = net.compute_parameter_gradients(next_job.t);
-            net.update_parameters(make_sstack(solvers), step_size);
-            record_loss(loss);
+            if (next_job.have_data[device])
+            {
+                auto&& dev = *devices[device];
+                dlib::cuda::set_device(dev.device_id);
+                no_label_type pick_which_run_update;
+                return dev.net.compute_parameter_gradients(next_job.t[device]);
+            }
+            else
+            {
+                return 0;
+            }
         }
 
+        void update_parameters(size_t device)
+        {
+            auto&& dev = *devices[device];
+            dlib::cuda::set_device(dev.device_id);
+            dev.net.update_parameters(make_sstack(dev.solvers), step_size);
+        }
+
         void thread() try
         {
-            // Make sure this thread uses the same cuda device as the thread that created
-            // the dnn_trainer object.
-            dlib::cuda::set_device(cuda_device_id);
             label_type pick_which_run_update;
             job_t next_job;
+
+            std::vector<std::future<double>> losses(devices.size());
+            std::vector<std::future<void>> update_futs(devices.size());
+            std::vector<matrix<float>> param_buffer(net_type::num_computational_layers);
             while (job_pipe.dequeue(next_job))
             {
-                // call net.compute_parameter_gradients() and net.update_parameters() but
-                // pick the right version for unsupervised or supervised training based on
-                // the type of label_type.
-                run_update(next_job, pick_which_run_update);
+                // Call compute_parameter_gradients() and update_parameters() but pick the
+                // right version for unsupervised or supervised training based on the type
+                // of label_type.
+                for (size_t i = 0; i < devices.size(); ++i)
+                    losses[i] = std::async(std::launch::async,[&,i](){ return compute_parameter_gradients(i, next_job, pick_which_run_update); });
+                // aggregate loss values from all the network computations.
+                for (auto&& loss : losses)
+                    record_loss(loss.get());
+
+                // Now, if there is more than one active device we need to synchronize the
+                // gradient updates between devices.  So we do that now.
+                if (devices.size() > 1)
+                {
+                    for (auto&& p : param_buffer)
+                        p = 0;
+                    // now average all the parameter gradients
+                    for (size_t i = 0; i < devices.size(); ++i)
+                    {
+                        visit_layer_parameters(devices[i]->net, [&param_buffer](size_t j, tensor& t) {
+                            if (t.size() != 0)
+                                param_buffer[j] += mat(t);
+                        });
+                    }
+                    // and then assign the parameter gradients back to all the networks
+                    const float scale = 1.0f/devices.size();
+                    for (size_t i = 0; i < devices.size(); ++i)
+                    {
+                        visit_layer_parameters(devices[i]->net, [scale,&param_buffer](size_t j, tensor& t) {
+                            if (t.size() != 0)
+                            {
+                                t = param_buffer[j]*scale;
+                                t.async_copy_to_device();
+                            }
+                        });
+                    }
+                }
+
+                // Now apply all the updates to each device.
+                for (size_t i = 0; i < devices.size(); ++i)
+                    update_futs[i] = std::async(std::launch::async, [&,i](){ if (next_job.have_data[i]) update_parameters(i); });
+                // and wait for the updates to all happen.
+                for (auto&& f : update_futs)
+                    f.wait();
 
                 // If we have been running for a while then check if the loss is still
                 // dropping.  If it isn't then we will reduce the step size.  Note that we
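The synchronization step above is plain averaging across devices. Writing $t_j^{(d)}$ for the layer-$j$ tensor that visit_layer_parameters exposes on device $d$, and $D$ for the number of devices:

    \bar{t}_j = \frac{1}{D}\sum_{d=1}^{D} t_j^{(d)},
    \qquad
    t_j^{(d)} \leftarrow \bar{t}_j \text{ for all } d

Every device then applies its own solver stack through update_parameters(i); devices whose have_data[i] flag is false skip that update but still contribute their (unchanged) tensors to the average. One reading worth noting: the comments speak of averaging "parameter gradients", while visit_layer_parameters by its name walks the layer parameter tensors themselves.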
@@ -484,7 +589,6 @@ namespace dlib
             max_num_epochs = 10000;
             mini_batch_size = 128;
             verbose = false;
-            cuda_device_id = dlib::cuda::get_device();
             step_size = 1;
             min_step_size = 1e-3;
             iter_without_progress_thresh = 2000;
@@ -504,10 +608,10 @@ namespace dlib
         friend void serialize(const dnn_trainer& item, std::ostream& out)
         {
             item.wait_for_thread_to_pause();
-            int version = 5;
+            int version = 6;
             serialize(version, out);
 
-            size_t nl = dnn_trainer::num_computational_layers;
+            size_t nl = dnn_trainer::num_layers;
             serialize(nl, out);
             serialize(item.rs, out);
             serialize(item.previous_loss_values, out);
@@ -515,7 +619,7 @@ namespace dlib
             serialize(item.mini_batch_size, out);
             serialize(item.verbose, out);
             serialize(item.net, out);
-            serialize(item.solvers, out);
+            serialize(item.devices[0]->solvers, out);
             serialize(item.step_size.load(), out);
             serialize(item.min_step_size, out);
             serialize(item.iter_without_progress_thresh.load(), out);
@@ -530,17 +634,17 @@ namespace dlib
             item.wait_for_thread_to_pause();
             int version = 0;
             deserialize(version, in);
-            if (version != 5)
+            if (version != 6)
                 throw serialization_error("Unexpected version found while deserializing dlib::dnn_trainer.");
 
-            size_t num_computational_layers = 0;
-            deserialize(num_computational_layers, in);
-            if (num_computational_layers != dnn_trainer::num_computational_layers)
+            size_t num_layers = 0;
+            deserialize(num_layers, in);
+            if (num_layers != dnn_trainer::num_layers)
             {
                 std::ostringstream sout;
                 sout << "Error deserializing dlib::dnn_trainer.  The saved sync file is for a network with " << std::endl;
-                sout << "a different number of layers.  We expected the number of computational layers to be "
-                     << dnn_trainer::num_computational_layers << " but" << std::endl;
-                sout << "instead the file contains " << num_computational_layers << " computational layers." << std::endl;
+                sout << "a different number of layers.  We expected the number of layers to be "
+                     << dnn_trainer::num_layers << " but" << std::endl;
+                sout << "instead the file contains " << num_layers << " layers." << std::endl;
                 throw serialization_error(sout.str());
             }
@@ -551,7 +655,7 @@ namespace dlib
             deserialize(item.mini_batch_size, in);
             deserialize(item.verbose, in);
             deserialize(item.net, in);
-            deserialize(item.solvers, in);
+            deserialize(item.devices[0]->solvers, in);
             deserialize(dtemp, in); item.step_size = dtemp;
             deserialize(item.min_step_size, in);
             deserialize(ltemp, in); item.iter_without_progress_thresh = ltemp;
@@ -560,6 +664,21 @@ namespace dlib
             deserialize(item.epoch_iteration, in);
             deserialize(item.epoch_pos, in);
             deserialize(item.train_one_step_calls, in);
+
+            if (item.devices.size() > 1)
+            {
+                const auto prev_dev = dlib::cuda::get_device();
+                // initialize all the other device networks and solver objects
+                for (size_t i = 1; i < item.devices.size(); ++i)
+                {
+                    // Switch to this device so that any tensor objects that get allocated when
+                    // we copy this stuff happen on this device.
+                    dlib::cuda::set_device(item.devices[i]->device_id);
+                    item.devices[i]->solvers = item.devices[0]->solvers;
+                    item.devices[i]->net = item.devices[0]->net;
+                }
+                dlib::cuda::set_device(prev_dev);
+            }
         }
 
         void sync_to_disk (
             bool do_it_now = false
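Because per-device solver state changed what gets written to disk, the sync-file format version was bumped from 5 to 6, and the sanity check now compares num_layers rather than num_computational_layers. The practical consequence, sketched with a hypothetical filename (set_synchronization_file() is the trainer's existing sync mechanism; the exception text comes from the deserialize() code above):

    dnn_trainer<net_type> trainer(net);
    trainer.set_synchronization_file("mnist_sync", std::chrono::seconds(20));
    // If "mnist_sync" was written before this commit (format version 5), resuming
    // throws dlib::serialization_error("Unexpected version found ...").
    trainer.train(images, labels);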
@@ -594,16 +713,96 @@ namespace dlib
         }
 
+        struct clone_net{};
+
+        // per device state.  All the containers have the same number of objects in them.
+        struct device_data
+        {
+            device_data(
+                int device_id_,
+                net_type& net_,
+                const solver_type& solver_
+            ) : device_id(device_id_), net(net_), solvers(num_computational_layers, solver_) {}
+
+            device_data(
+                int device_id_,
+                net_type& net_,
+                const solver_type& solver_,
+                clone_net
+            ) : device_id(device_id_), net_copy(std::make_shared<net_type>(net_)),
+                net(*net_copy), solvers(num_computational_layers, solver_) {}
+
+            int device_id;
+            std::shared_ptr<net_type> net_copy;
+            net_type& net;
+            std::vector<solver_type> solvers;
+        };
+
+        template <typename data_iterator, typename label_iterator>
+        void send_job (
+            data_iterator dbegin,
+            data_iterator dend,
+            label_iterator lbegin
+        )
+        {
+            size_t num = std::distance(dbegin, dend);
+            size_t devs = devices.size();
+            job.t.resize(devs);
+            job.labels.resize(devs);
+            job.have_data.resize(devs);
+
+            // chop the data into devs blocks, each of about block_size elements.
+            size_t block_size = (num+devs-1)/devs;
+
+            const auto prev_dev = dlib::cuda::get_device();
+            for (size_t i = 0; i < devs; ++i)
+            {
+                dlib::cuda::set_device(devices[i]->device_id);
+
+                size_t start = i*block_size;
+                size_t stop  = std::min(num, start+block_size);
+
+                if (start < stop)
+                {
+                    devices[i]->net.to_tensor(dbegin+start, dbegin+stop, job.t[i]);
+                    job.labels[i].assign(lbegin+start, lbegin+stop);
+                    job.have_data[i] = true;
+                }
+                else
+                {
+                    job.have_data[i] = false;
+                }
+            }
+            dlib::cuda::set_device(prev_dev);
+
+            job_pipe.enqueue(job);
+        }
+
+        template <typename data_iterator>
+        void send_job (
+            data_iterator dbegin,
+            data_iterator dend
+        )
+        {
+            typename std::vector<label_type>::iterator nothing;
+            send_job(dbegin, dend, nothing);
+        }
+
+        std::vector<std::shared_ptr<device_data>> devices;
         dlib::pipe<job_t> job_pipe;
-        job_t job;
         running_stats<double> rs;
         std::deque<double> previous_loss_values;
 
         unsigned long max_num_epochs;
         size_t mini_batch_size;
         bool verbose;
-        int cuda_device_id;
         net_type& net;
-        std::vector<solver_type> solvers;
         std::atomic<double> step_size;
         double min_step_size;
         std::atomic<unsigned long> iter_without_progress_thresh;
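A quick worked example of the sharding arithmetic in send_job() (hypothetical numbers): with num = 100 samples and devs = 3 devices, block_size = (100+3-1)/3 = 34, so the shards are [0,34), [34,68), and [68,100). With num = 2 and devs = 3 the third shard is empty (start >= stop), so have_data[2] is set to false and that device is skipped by both compute_parameter_gradients() and update_parameters().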
@@ -618,9 +817,8 @@ namespace dlib
         unsigned long long train_one_step_calls;
         unsigned long gradient_check_budget;
 
+        // The job object is not logically part of the state of this object.  It is here
+        // only to avoid reallocating it over and over.
+        job_t job;
     };
 
// ----------------------------------------------------------------------------------------
dlib/dnn/trainer_abstract.h
@@ -48,12 +48,17 @@ namespace dlib
         dnn_trainer() = delete;
         dnn_trainer(const dnn_trainer&) = delete;
         dnn_trainer& operator=(const dnn_trainer&) = delete;
 
         dnn_trainer(
             net_type& net,
-            const solver_type& solver = solver_type()
+            const solver_type& solver = solver_type(),
+            const std::vector<int>& cuda_extra_devices = {}
         );
         /*!
+            requires
+                - for all valid i:
+                    - 0 <= cuda_extra_devices[i] < dlib::cuda::get_num_devices()
             ensures
                 - &#get_net() == &net
                   (i.e. The dnn_trainer holds a reference to net, it does not copy it.
@@ -67,6 +72,13 @@ namespace dlib
                 - #get_min_step_size() == 1e-3
                 - #get_iterations_without_progress_threshold() == 2000
                 - #get_step_size_shrink() == 0.1
+                - if (cuda_extra_devices.size() > 0) then
+                    - This object will use multiple graphics cards to run the learning
+                      algorithms.  In particular, it will always use whatever device is
+                      currently selected on the calling thread (the device indicated by
+                      cudaGetDevice()).  In addition, you can ask to use additional
+                      devices, which you do by putting their device numbers into
+                      cuda_extra_devices.
         !*/
 
         net_type& get_net (
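A short sketch of how that ensures-clause plays out in practice (hypothetical device numbers; assumes a machine with at least three CUDA devices and the net/net_type placeholders from the sketch at the top):

    // The primary device is whatever is current on the calling thread.
    dlib::cuda::set_device(0);
    // Use device 0 plus devices 1 and 2; each mini-batch is split three ways.
    dnn_trainer<net_type> trainer(net, sgd(), {1, 2});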
@@ -82,15 +94,6 @@ namespace dlib
               stopped touching the net.
         !*/
 
-        void set_solver (
-            const solver_type& solver
-        );
-        /*!
-            ensures
-                - assigns solver to all the solvers in this object.  I.e. solver will be
-                  assigned to each element in get_solvers().
-        !*/
-
         const std::vector<solver_type>& get_solvers (
         ) const;
         /*!
@@ -101,22 +104,6 @@ namespace dlib
                   get_solvers()[1], and so on.
         !*/
 
-        std::vector<solver_type>& get_solvers (
-        );
-        /*!
-            ensures
-                - returns the solvers used to optimize each layer of the neural network
-                  get_net().  In particular, the first layer's solver is
-                  get_solvers()[0], the second layer's solver is get_solvers()[1], and so on.
-                - It should be noted that you should never change the number of elements in
-                  the vector returned by get_solvers() (i.e. don't do something that changes
-                  get_solvers().size()).  It will be set to net_type::num_computational_layers
-                  by this object and you should leave it at that.  The non-const version of
-                  get_solvers() is provided only so you can tweak the parameters of a
-                  particular solver.
-        !*/
-
         unsigned long get_mini_batch_size (
         ) const;
         /*!
@@ -289,6 +276,7 @@ namespace dlib
         /*!
             requires
                 - data.size() == labels.size()
+                - data.size() > 0
                 - net_type uses a supervised loss.
                   i.e. net_type::label_type != no_label_type.
             ensures
@@ -314,6 +302,7 @@ namespace dlib
         );
         /*!
             requires
+                - data.size() > 0
                 - net_type uses an unsupervised loss.
                   i.e. net_type::label_type == no_label_type.
             ensures
@@ -341,6 +330,7 @@ namespace dlib
         /*!
             requires
                 - data.size() == labels.size()
+                - data.size() > 0
                 - net_type uses a supervised loss.
                   i.e. net_type::label_type != no_label_type.
             ensures
@@ -363,6 +353,7 @@ namespace dlib
         );
         /*!
             requires
+                - data.size() > 0
                 - net_type uses an unsupervised loss.
                   i.e. net_type::label_type == no_label_type.
             ensures