Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
160337da
Commit
160337da
authored
Nov 18, 2013
by
Davis King
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Made the one_vs_one_trainer and one_vs_all_trainer objects multithreaded
so they can run each binary trainer on a different core.
parent
525f2a52
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
254 additions
and
76 deletions
+254
-76
svm.h
dlib/svm.h
+0
-2
cross_validate_multiclass_trainer.h
dlib/svm/cross_validate_multiclass_trainer.h
+0
-1
one_vs_all_trainer.h
dlib/svm/one_vs_all_trainer.h
+104
-34
one_vs_all_trainer_abstract.h
dlib/svm/one_vs_all_trainer_abstract.h
+22
-4
one_vs_one_trainer.h
dlib/svm/one_vs_one_trainer.h
+102
-29
one_vs_one_trainer_abstract.h
dlib/svm/one_vs_one_trainer_abstract.h
+22
-4
svm_threaded.h
dlib/svm_threaded.h
+2
-0
one_vs_all_trainer.cpp
dlib/test/one_vs_all_trainer.cpp
+1
-1
one_vs_one_trainer.cpp
dlib/test/one_vs_one_trainer.cpp
+1
-1
No files found.
dlib/svm.h
View file @
160337da
...
@@ -32,7 +32,6 @@
...
@@ -32,7 +32,6 @@
#include "svm/svr_trainer.h"
#include "svm/svr_trainer.h"
#include "svm/one_vs_one_decision_function.h"
#include "svm/one_vs_one_decision_function.h"
#include "svm/one_vs_one_trainer.h"
#include "svm/multiclass_tools.h"
#include "svm/multiclass_tools.h"
#include "svm/cross_validate_multiclass_trainer.h"
#include "svm/cross_validate_multiclass_trainer.h"
#include "svm/cross_validate_regression_trainer.h"
#include "svm/cross_validate_regression_trainer.h"
...
@@ -42,7 +41,6 @@
...
@@ -42,7 +41,6 @@
#include "svm/cross_validate_assignment_trainer.h"
#include "svm/cross_validate_assignment_trainer.h"
#include "svm/one_vs_all_decision_function.h"
#include "svm/one_vs_all_decision_function.h"
#include "svm/one_vs_all_trainer.h"
#include "svm/structural_svm_problem.h"
#include "svm/structural_svm_problem.h"
#include "svm/sequence_labeler.h"
#include "svm/sequence_labeler.h"
...
...
dlib/svm/cross_validate_multiclass_trainer.h
View file @
160337da
...
@@ -5,7 +5,6 @@
...
@@ -5,7 +5,6 @@
#include <vector>
#include <vector>
#include "../matrix.h"
#include "../matrix.h"
#include "one_vs_one_trainer.h"
#include "cross_validate_multiclass_trainer_abstract.h"
#include "cross_validate_multiclass_trainer_abstract.h"
#include <sstream>
#include <sstream>
...
...
dlib/svm/one_vs_all_trainer.h
View file @
160337da
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include "../any.h"
#include "../any.h"
#include <map>
#include <map>
#include <set>
#include <set>
#include "../threads.h"
namespace
dlib
namespace
dlib
{
{
...
@@ -39,7 +40,8 @@ namespace dlib
...
@@ -39,7 +40,8 @@ namespace dlib
one_vs_all_trainer
(
one_vs_all_trainer
(
)
:
)
:
verbose
(
false
)
verbose
(
false
),
num_threads
(
4
)
{}
{}
void
set_trainer
(
void
set_trainer
(
...
@@ -70,6 +72,19 @@ namespace dlib
...
@@ -70,6 +72,19 @@ namespace dlib
verbose
=
false
;
verbose
=
false
;
}
}
void
set_num_threads
(
unsigned
long
num
)
{
num_threads
=
num
;
}
unsigned
long
get_num_threads
(
)
const
{
return
num_threads
;
}
struct
invalid_label
:
public
dlib
::
error
struct
invalid_label
:
public
dlib
::
error
{
{
invalid_label
(
const
std
::
string
&
msg
,
const
label_type
&
l_
invalid_label
(
const
std
::
string
&
msg
,
const
label_type
&
l_
...
@@ -96,62 +111,117 @@ namespace dlib
...
@@ -96,62 +111,117 @@ namespace dlib
const
std
::
vector
<
label_type
>
distinct_labels
=
select_all_distinct_labels
(
all_labels
);
const
std
::
vector
<
label_type
>
distinct_labels
=
select_all_distinct_labels
(
all_labels
);
std
::
vector
<
scalar_type
>
labels
;
// make sure we have a trainer object for each of the label types.
typename
trained_function_type
::
binary_function_table
dfs
;
for
(
unsigned
long
i
=
0
;
i
<
distinct_labels
.
size
();
++
i
)
for
(
unsigned
long
i
=
0
;
i
<
distinct_labels
.
size
();
++
i
)
{
{
labels
.
clear
();
const
label_type
l
=
distinct_labels
[
i
];
const
label_type
l
=
distinct_labels
[
i
];
const
typename
binary_function_table
::
const_iterator
itr
=
trainers
.
find
(
l
);
// setup one of the one vs all training sets
if
(
itr
==
trainers
.
end
()
&&
default_trainer
.
is_empty
())
for
(
unsigned
long
k
=
0
;
k
<
all_samples
.
size
();
++
k
)
{
{
if
(
all_labels
[
k
]
==
l
)
std
::
ostringstream
sout
;
labels
.
push_back
(
+
1
);
sout
<<
"In one_vs_all_trainer, no trainer registered for the "
<<
l
<<
" label."
;
else
throw
invalid_label
(
sout
.
str
(),
l
);
labels
.
push_back
(
-
1
);
}
}
}
if
(
verbose
)
// now do the training
{
parallel_for_helper
helper
(
all_samples
,
all_labels
,
default_trainer
,
trainers
,
verbose
,
distinct_labels
);
std
::
cout
<<
"Training classifier for "
<<
l
<<
" vs. all"
<<
std
::
endl
;
parallel_for
(
num_threads
,
0
,
distinct_labels
.
size
(),
helper
,
500
);
}
// now train a binary classifier using the samples we selected
if
(
helper
.
error_message
.
size
()
!=
0
)
const
typename
binary_function_table
::
const_iterator
itr
=
trainers
.
find
(
l
);
{
throw
dlib
::
error
(
"binary trainer threw while training one vs. all classifier. Error was: "
+
helper
.
error_message
);
}
return
trained_function_type
(
helper
.
dfs
);
}
if
(
itr
!=
trainers
.
end
())
private
:
{
dfs
[
l
]
=
itr
->
second
.
train
(
all_samples
,
labels
);
typedef
std
::
map
<
label_type
,
any_trainer
>
binary_function_table
;
}
struct
parallel_for_helper
else
if
(
default_trainer
.
is_empty
()
==
false
)
{
parallel_for_helper
(
const
std
::
vector
<
sample_type
>&
all_samples_
,
const
std
::
vector
<
label_type
>&
all_labels_
,
const
any_trainer
&
default_trainer_
,
const
binary_function_table
&
trainers_
,
const
bool
verbose_
,
const
std
::
vector
<
label_type
>&
distinct_labels_
)
:
all_samples
(
all_samples_
),
all_labels
(
all_labels_
),
default_trainer
(
default_trainer_
),
trainers
(
trainers_
),
verbose
(
verbose_
),
distinct_labels
(
distinct_labels_
)
{}
void
operator
()(
long
i
)
const
{
try
{
{
dfs
[
l
]
=
default_trainer
.
train
(
all_samples
,
labels
);
std
::
vector
<
scalar_type
>
labels
;
const
label_type
l
=
distinct_labels
[
i
];
// setup one of the one vs all training sets
for
(
unsigned
long
k
=
0
;
k
<
all_samples
.
size
();
++
k
)
{
if
(
all_labels
[
k
]
==
l
)
labels
.
push_back
(
+
1
);
else
labels
.
push_back
(
-
1
);
}
if
(
verbose
)
{
auto_mutex
lock
(
class_mutex
);
std
::
cout
<<
"Training classifier for "
<<
l
<<
" vs. all"
<<
std
::
endl
;
}
any_trainer
trainer
;
// now train a binary classifier using the samples we selected
{
auto_mutex
lock
(
class_mutex
);
const
typename
binary_function_table
::
const_iterator
itr
=
trainers
.
find
(
l
);
if
(
itr
!=
trainers
.
end
())
trainer
=
itr
->
second
;
else
trainer
=
default_trainer
;
}
any_decision_function
<
sample_type
,
scalar_type
>
binary_df
=
trainer
.
train
(
all_samples
,
labels
);
auto_mutex
lock
(
class_mutex
);
dfs
[
l
]
=
binary_df
;
}
}
else
catch
(
std
::
exception
&
e
)
{
{
std
::
ostringstream
sout
;
auto_mutex
lock
(
class_mutex
);
sout
<<
"In one_vs_all_trainer, no trainer registered for the "
<<
l
<<
" label."
;
error_message
=
e
.
what
();
throw
invalid_label
(
sout
.
str
(),
l
);
}
}
}
}
return
trained_function_type
(
dfs
);
mutable
typename
trained_function_type
::
binary_function_table
dfs
;
}
mutex
class_mutex
;
mutable
std
::
string
error_message
;
private
:
const
std
::
vector
<
sample_type
>&
all_samples
;
const
std
::
vector
<
label_type
>&
all_labels
;
const
any_trainer
&
default_trainer
;
const
binary_function_table
&
trainers
;
const
bool
verbose
;
const
std
::
vector
<
label_type
>&
distinct_labels
;
};
any_trainer
default_trainer
;
any_trainer
default_trainer
;
typedef
std
::
map
<
label_type
,
any_trainer
>
binary_function_table
;
binary_function_table
trainers
;
binary_function_table
trainers
;
bool
verbose
;
bool
verbose
;
unsigned
long
num_threads
;
};
};
...
...
dlib/svm/one_vs_all_trainer_abstract.h
View file @
160337da
...
@@ -55,10 +55,11 @@ namespace dlib
...
@@ -55,10 +55,11 @@ namespace dlib
);
);
/*!
/*!
ensures
ensures
- this object is properly initialized
- This object is properly initialized.
- this object will not be verbose unless be_verbose() is called
- This object will not be verbose unless be_verbose() is called.
- no binary trainers are associated with *this. I.e. you have to
- No binary trainers are associated with *this. I.e. you have to
call set_trainer() before calling train()
call set_trainer() before calling train().
- #get_num_threads() == 4
!*/
!*/
void
set_trainer
(
void
set_trainer
(
...
@@ -96,6 +97,23 @@ namespace dlib
...
@@ -96,6 +97,23 @@ namespace dlib
- this object will not print anything to standard out
- this object will not print anything to standard out
!*/
!*/
void
set_num_threads
(
unsigned
long
num
);
/*!
ensures
- #get_num_threads() == num
!*/
unsigned
long
get_num_threads
(
)
const
;
/*!
ensures
- returns the number of threads used during training. You should
usually set this equal to the number of processing cores on your
machine.
!*/
struct
invalid_label
:
public
dlib
::
error
struct
invalid_label
:
public
dlib
::
error
{
{
/*!
/*!
...
...
dlib/svm/one_vs_one_trainer.h
View file @
160337da
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "../any.h"
#include "../any.h"
#include <map>
#include <map>
#include <set>
#include <set>
#include "../threads.h"
namespace
dlib
namespace
dlib
{
{
...
@@ -40,7 +41,8 @@ namespace dlib
...
@@ -40,7 +41,8 @@ namespace dlib
one_vs_one_trainer
(
one_vs_one_trainer
(
)
:
)
:
verbose
(
false
)
verbose
(
false
),
num_threads
(
4
)
{}
{}
void
set_trainer
(
void
set_trainer
(
...
@@ -72,6 +74,19 @@ namespace dlib
...
@@ -72,6 +74,19 @@ namespace dlib
verbose
=
false
;
verbose
=
false
;
}
}
void
set_num_threads
(
unsigned
long
num
)
{
num_threads
=
num
;
}
unsigned
long
get_num_threads
(
)
const
{
return
num_threads
;
}
struct
invalid_label
:
public
dlib
::
error
struct
invalid_label
:
public
dlib
::
error
{
{
invalid_label
(
const
std
::
string
&
msg
,
const
label_type
&
l1_
,
const
label_type
&
l2_
invalid_label
(
const
std
::
string
&
msg
,
const
label_type
&
l1_
,
const
label_type
&
l2_
...
@@ -98,20 +113,70 @@ namespace dlib
...
@@ -98,20 +113,70 @@ namespace dlib
const
std
::
vector
<
label_type
>
distinct_labels
=
select_all_distinct_labels
(
all_labels
);
const
std
::
vector
<
label_type
>
distinct_labels
=
select_all_distinct_labels
(
all_labels
);
std
::
vector
<
sample_type
>
samples
;
std
::
vector
<
scalar_type
>
labels
;
typename
trained_function_type
::
binary_function_table
dfs
;
// fill pairs with all the pairs of labels.
std
::
vector
<
unordered_pair
<
label_type
>
>
pairs
;
for
(
unsigned
long
i
=
0
;
i
<
distinct_labels
.
size
();
++
i
)
for
(
unsigned
long
i
=
0
;
i
<
distinct_labels
.
size
();
++
i
)
{
{
for
(
unsigned
long
j
=
i
+
1
;
j
<
distinct_labels
.
size
();
++
j
)
for
(
unsigned
long
j
=
i
+
1
;
j
<
distinct_labels
.
size
();
++
j
)
{
{
samples
.
clear
();
pairs
.
push_back
(
unordered_pair
<
label_type
>
(
distinct_labels
[
i
],
distinct_labels
[
j
]));
labels
.
clear
();
const
unordered_pair
<
label_type
>
p
(
distinct_labels
[
i
],
distinct_labels
[
j
]);
// make sure we have a trainer for this pair
const
typename
binary_function_table
::
const_iterator
itr
=
trainers
.
find
(
pairs
.
back
());
if
(
itr
==
trainers
.
end
()
&&
default_trainer
.
is_empty
())
{
std
::
ostringstream
sout
;
sout
<<
"In one_vs_one_trainer, no trainer registered for the ("
<<
pairs
.
back
().
first
<<
", "
<<
pairs
.
back
().
second
<<
") label pair."
;
throw
invalid_label
(
sout
.
str
(),
pairs
.
back
().
first
,
pairs
.
back
().
second
);
}
}
}
// Now train on all the label pairs.
parallel_for_helper
helper
(
all_samples
,
all_labels
,
default_trainer
,
trainers
,
verbose
,
pairs
);
parallel_for
(
num_threads
,
0
,
pairs
.
size
(),
helper
,
500
);
if
(
helper
.
error_message
.
size
()
!=
0
)
{
throw
dlib
::
error
(
"binary trainer threw while training one vs. one classifier. Error was: "
+
helper
.
error_message
);
}
return
trained_function_type
(
helper
.
dfs
);
}
private
:
typedef
std
::
map
<
unordered_pair
<
label_type
>
,
any_trainer
>
binary_function_table
;
struct
parallel_for_helper
{
parallel_for_helper
(
const
std
::
vector
<
sample_type
>&
all_samples_
,
const
std
::
vector
<
label_type
>&
all_labels_
,
const
any_trainer
&
default_trainer_
,
const
binary_function_table
&
trainers_
,
const
bool
verbose_
,
const
std
::
vector
<
unordered_pair
<
label_type
>
>&
pairs_
)
:
all_samples
(
all_samples_
),
all_labels
(
all_labels_
),
default_trainer
(
default_trainer_
),
trainers
(
trainers_
),
verbose
(
verbose_
),
pairs
(
pairs_
)
{}
void
operator
()(
long
i
)
const
{
try
{
std
::
vector
<
sample_type
>
samples
;
std
::
vector
<
scalar_type
>
labels
;
const
unordered_pair
<
label_type
>
p
=
pairs
[
i
];
// pick out the samples corresponding to these two classes
// pick out the samples corresponding to these two classes
for
(
unsigned
long
k
=
0
;
k
<
all_samples
.
size
();
++
k
)
for
(
unsigned
long
k
=
0
;
k
<
all_samples
.
size
();
++
k
)
...
@@ -128,43 +193,51 @@ namespace dlib
...
@@ -128,43 +193,51 @@ namespace dlib
}
}
}
}
if
(
verbose
)
if
(
verbose
)
{
{
auto_mutex
lock
(
class_mutex
);
std
::
cout
<<
"Training classifier for "
<<
p
.
first
<<
" vs. "
<<
p
.
second
<<
std
::
endl
;
std
::
cout
<<
"Training classifier for "
<<
p
.
first
<<
" vs. "
<<
p
.
second
<<
std
::
endl
;
}
}
any_trainer
trainer
;
// now train a binary classifier using the samples we selected
// now train a binary classifier using the samples we selected
{
auto_mutex
lock
(
class_mutex
);
const
typename
binary_function_table
::
const_iterator
itr
=
trainers
.
find
(
p
);
const
typename
binary_function_table
::
const_iterator
itr
=
trainers
.
find
(
p
);
if
(
itr
!=
trainers
.
end
())
if
(
itr
!=
trainers
.
end
())
{
trainer
=
itr
->
second
;
dfs
[
p
]
=
itr
->
second
.
train
(
samples
,
labels
);
else
}
trainer
=
default_trainer
;
else
if
(
default_trainer
.
is_empty
()
==
false
)
{
dfs
[
p
]
=
default_trainer
.
train
(
samples
,
labels
);
}
else
{
std
::
ostringstream
sout
;
sout
<<
"In one_vs_one_trainer, no trainer registered for the ("
<<
p
.
first
<<
", "
<<
p
.
second
<<
") label pair."
;
throw
invalid_label
(
sout
.
str
(),
p
.
first
,
p
.
second
);
}
}
any_decision_function
<
sample_type
,
scalar_type
>
binary_df
=
trainer
.
train
(
samples
,
labels
);
auto_mutex
lock
(
class_mutex
);
dfs
[
p
]
=
binary_df
;
}
catch
(
std
::
exception
&
e
)
{
auto_mutex
lock
(
class_mutex
);
error_message
=
e
.
what
();
}
}
}
}
return
trained_function_type
(
dfs
);
mutable
typename
trained_function_type
::
binary_function_table
dfs
;
}
mutex
class_mutex
;
mutable
std
::
string
error_message
;
private
:
const
std
::
vector
<
sample_type
>&
all_samples
;
const
std
::
vector
<
label_type
>&
all_labels
;
const
any_trainer
&
default_trainer
;
const
binary_function_table
&
trainers
;
const
bool
verbose
;
const
std
::
vector
<
unordered_pair
<
label_type
>
>&
pairs
;
};
any_trainer
default_trainer
;
any_trainer
default_trainer
;
typedef
std
::
map
<
unordered_pair
<
label_type
>
,
any_trainer
>
binary_function_table
;
binary_function_table
trainers
;
binary_function_table
trainers
;
bool
verbose
;
bool
verbose
;
unsigned
long
num_threads
;
};
};
...
...
dlib/svm/one_vs_one_trainer_abstract.h
View file @
160337da
...
@@ -55,10 +55,11 @@ namespace dlib
...
@@ -55,10 +55,11 @@ namespace dlib
);
);
/*!
/*!
ensures
ensures
- this object is properly initialized
- This object is properly initialized
- this object will not be verbose unless be_verbose() is called
- This object will not be verbose unless be_verbose() is called.
- no binary trainers are associated with *this. I.e. you have to
- No binary trainers are associated with *this. I.e. you have to
call set_trainer() before calling train()
call set_trainer() before calling train().
- #get_num_threads() == 4
!*/
!*/
void
set_trainer
(
void
set_trainer
(
...
@@ -99,6 +100,23 @@ namespace dlib
...
@@ -99,6 +100,23 @@ namespace dlib
- this object will not print anything to standard out
- this object will not print anything to standard out
!*/
!*/
void
set_num_threads
(
unsigned
long
num
);
/*!
ensures
- #get_num_threads() == num
!*/
unsigned
long
get_num_threads
(
)
const
;
/*!
ensures
- returns the number of threads used during training. You should
usually set this equal to the number of processing cores on your
machine.
!*/
struct
invalid_label
:
public
dlib
::
error
struct
invalid_label
:
public
dlib
::
error
{
{
/*!
/*!
...
...
dlib/svm_threaded.h
View file @
160337da
...
@@ -19,6 +19,8 @@
...
@@ -19,6 +19,8 @@
#include "svm/structural_graph_labeling_trainer.h"
#include "svm/structural_graph_labeling_trainer.h"
#include "svm/cross_validate_graph_labeling_trainer.h"
#include "svm/cross_validate_graph_labeling_trainer.h"
#include "svm/svm_multiclass_linear_trainer.h"
#include "svm/svm_multiclass_linear_trainer.h"
#include "svm/one_vs_one_trainer.h"
#include "svm/one_vs_all_trainer.h"
#endif // DLIB_SVm_THREADED_HEADER
#endif // DLIB_SVm_THREADED_HEADER
...
...
dlib/test/one_vs_all_trainer.cpp
View file @
160337da
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
// License: Boost Software License See LICENSE.txt for the full license.
// License: Boost Software License See LICENSE.txt for the full license.
#include "tester.h"
#include "tester.h"
#include <dlib/svm.h>
#include <dlib/svm
_threaded
.h>
#include <vector>
#include <vector>
#include <sstream>
#include <sstream>
...
...
dlib/test/one_vs_one_trainer.cpp
View file @
160337da
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
// License: Boost Software License See LICENSE.txt for the full license.
// License: Boost Software License See LICENSE.txt for the full license.
#include "tester.h"
#include "tester.h"
#include <dlib/svm.h>
#include <dlib/svm
_threaded
.h>
#include <dlib/statistics.h>
#include <dlib/statistics.h>
#include <vector>
#include <vector>
#include <sstream>
#include <sstream>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment