Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
d9d6fa12
Commit
d9d6fa12
authored
May 02, 2014
by
Davis King
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added the ability to set a previously trained function as a prior to the
svm_multiclass_linear_trainer.
parent
a7047b35
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
164 additions
and
8 deletions
+164
-8
svm_multiclass_linear_trainer.h
dlib/svm/svm_multiclass_linear_trainer.h
+73
-8
svm_multiclass_linear_trainer_abstract.h
dlib/svm/svm_multiclass_linear_trainer_abstract.h
+32
-0
svm_multiclass_linear.cpp
dlib/test/svm_multiclass_linear.cpp
+59
-0
No files found.
dlib/svm/svm_multiclass_linear_trainer.h
View file @
d9d6fa12
...
...
@@ -10,6 +10,7 @@
#include "../matrix.h"
#include "sparse_vector.h"
#include "function.h"
#include <algorithm>
namespace
dlib
{
...
...
@@ -46,13 +47,15 @@ namespace dlib
multiclass_svm_problem
(
const
std
::
vector
<
sample_type
>&
samples_
,
const
std
::
vector
<
label_type
>&
labels_
,
const
std
::
vector
<
label_type
>&
distinct_labels_
,
const
unsigned
long
dims_
,
const
unsigned
long
num_threads
)
:
structural_svm_problem_threaded
<
matrix_type
,
std
::
vector
<
std
::
pair
<
unsigned
long
,
typename
matrix_type
::
type
>
>
>
(
num_threads
),
samples
(
samples_
),
labels
(
labels_
),
distinct_labels
(
select_all_distinct_labels
(
labels_
)
),
dims
(
max_index_plus_one
(
samples_
)
+
1
)
// +1 for the bias
distinct_labels
(
distinct_labels_
),
dims
(
dims_
+
1
)
// +1 for the bias
{}
virtual
long
get_num_dimensions
(
...
...
@@ -151,7 +154,7 @@ namespace dlib
const
std
::
vector
<
sample_type
>&
samples
;
const
std
::
vector
<
label_type
>&
labels
;
const
std
::
vector
<
label_type
>
distinct_labels
;
const
std
::
vector
<
label_type
>
&
distinct_labels
;
const
long
dims
;
};
...
...
@@ -260,6 +263,7 @@ namespace dlib
)
{
learn_nonnegative_weights
=
value
;
prior
=
trained_function_type
();
}
void
set_c
(
...
...
@@ -283,6 +287,20 @@ namespace dlib
return
C
;
}
void
set_prior
(
const
trained_function_type
&
prior_
)
{
prior
=
prior_
;
learn_nonnegative_weights
=
false
;
}
bool
has_prior
(
)
const
{
return
prior
.
labels
.
size
()
!=
0
;
}
trained_function_type
train
(
const
std
::
vector
<
sample_type
>&
all_samples
,
const
std
::
vector
<
label_type
>&
all_labels
...
...
@@ -306,9 +324,33 @@ namespace dlib
<<
"
\n\t
all_labels.size(): "
<<
all_labels
.
size
()
);
trained_function_type
df
;
df
.
labels
=
select_all_distinct_labels
(
all_labels
);
if
(
has_prior
())
{
df
.
labels
.
insert
(
df
.
labels
.
end
(),
prior
.
labels
.
begin
(),
prior
.
labels
.
end
());
df
.
labels
=
select_all_distinct_labels
(
df
.
labels
);
}
const
long
input_sample_dimensionality
=
max_index_plus_one
(
all_samples
);
// If the samples are sparse then the right thing to do is to take the max
// dimensionality between the prior and the new samples. But if the samples
// are dense vectors then they definitely all have to have exactly the same
// dimensionality.
const
long
dims
=
std
::
max
(
df
.
weights
.
nc
(),
input_sample_dimensionality
);
if
(
is_matrix
<
sample_type
>::
value
&&
has_prior
())
{
DLIB_ASSERT
(
input_sample_dimensionality
==
prior
.
weights
.
nc
(),
"
\t
trained_function_type svm_multiclass_linear_trainer::train(all_samples,all_labels)"
<<
"
\n\t
The training samples given to this function are not the same kind of training "
<<
"
\n\t
samples used to create the prior."
<<
"
\n\t
input_sample_dimensionality: "
<<
input_sample_dimensionality
<<
"
\n\t
prior.weights.nc(): "
<<
prior
.
weights
.
nc
()
);
}
typedef
matrix
<
scalar_type
,
0
,
1
>
w_type
;
w_type
weights
;
multiclass_svm_problem
<
w_type
,
sample_type
,
label_type
>
problem
(
all_samples
,
all_labels
,
num_threads
);
multiclass_svm_problem
<
w_type
,
sample_type
,
label_type
>
problem
(
all_samples
,
all_labels
,
df
.
labels
,
dims
,
num_threads
);
if
(
verbose
)
problem
.
be_verbose
();
...
...
@@ -322,12 +364,33 @@ namespace dlib
num_nonnegative
=
problem
.
get_num_dimensions
();
}
svm_objective
=
solver
(
problem
,
weights
,
num_nonnegative
);
if
(
!
has_prior
())
{
svm_objective
=
solver
(
problem
,
weights
,
num_nonnegative
);
}
else
{
matrix
<
scalar_type
>
temp
(
df
.
labels
.
size
(),
dims
);
w_type
b
(
df
.
labels
.
size
());
temp
=
0
;
b
=
0
;
// Copy the prior into the temp and b matrices. We have to do this row
// by row copy because the new training data might have new labels we
// haven't seen before and therefore the sizes of these matrices could be
// different.
for
(
unsigned
long
i
=
0
;
i
<
prior
.
labels
.
size
();
++
i
)
{
const
long
r
=
std
::
find
(
df
.
labels
.
begin
(),
df
.
labels
.
end
(),
prior
.
labels
[
i
])
-
df
.
labels
.
begin
();
set_rowm
(
temp
,
r
)
=
rowm
(
prior
.
weights
,
i
);
b
(
r
)
=
prior
.
b
(
i
);
}
const
w_type
prior_vect
=
reshape_to_column_vector
(
join_rows
(
temp
,
b
));
svm_objective
=
solver
(
problem
,
weights
,
prior_vect
);
}
trained_function_type
df
;
const
long
dims
=
max_index_plus_one
(
all_samples
);
df
.
labels
=
select_all_distinct_labels
(
all_labels
);
df
.
weights
=
colm
(
reshape
(
weights
,
df
.
labels
.
size
(),
dims
+
1
),
range
(
0
,
dims
-
1
));
df
.
b
=
colm
(
reshape
(
weights
,
df
.
labels
.
size
(),
dims
+
1
),
dims
);
return
df
;
...
...
@@ -341,6 +404,8 @@ namespace dlib
bool
verbose
;
oca
solver
;
bool
learn_nonnegative_weights
;
trained_function_type
prior
;
};
// ----------------------------------------------------------------------------------------
...
...
dlib/svm/svm_multiclass_linear_trainer_abstract.h
View file @
d9d6fa12
...
...
@@ -37,6 +37,7 @@ namespace dlib
- get_c() == 1
- this object will not be verbose unless be_verbose() is called
- #get_oca() == oca() (i.e. an instance of oca with default parameters)
- has_prior() == false
WHAT THIS OBJECT REPRESENTS
This object represents a tool for training a multiclass support
...
...
@@ -176,6 +177,29 @@ namespace dlib
- #learns_nonnegative_weights() == value
!*/
void
set_prior
(
const
trained_function_type
&
prior
);
/*!
ensures
- #has_prior() == true
- #learns_nonnegative_weights() == false
!*/
bool
has_prior
(
)
const
/*!
ensures
- returns true if a prior has been set and false otherwise. Having a prior
set means that you have called set_prior() and supplied a previously
trained function as a reference. In this case, any call to train() will
try to learn a function that matches the behavior of the prior as close
as possible but also fits the supplied training data. In more technical
detail, having a prior means we replace the ||w||^2 regularizer with one
of the form ||w-prior||^2 where w is the set of parameters for a learned
function.
!*/
trained_function_type
train
(
const
std
::
vector
<
sample_type
>&
all_samples
,
const
std
::
vector
<
label_type
>&
all_labels
...
...
@@ -183,6 +207,10 @@ namespace dlib
/*!
requires
- is_learning_problem(all_samples, all_labels)
- All the vectors in all_samples must have the same dimensionality.
- if (has_prior()) then
- The vectors in all_samples must have the same dimensionality as the
vectors used to train the prior given to set_prior().
ensures
- trains a multiclass SVM to solve the given multiclass classification problem.
- returns a multiclass_linear_decision_function F with the following properties:
...
...
@@ -200,6 +228,10 @@ namespace dlib
/*!
requires
- is_learning_problem(all_samples, all_labels)
- All the vectors in all_samples must have the same dimensionality.
- if (has_prior()) then
- The vectors in all_samples must have the same dimensionality as the
vectors used to train the prior given to set_prior().
ensures
- trains a multiclass SVM to solve the given multiclass classification problem.
- returns a multiclass_linear_decision_function F with the following properties:
...
...
dlib/test/svm_multiclass_linear.cpp
View file @
d9d6fa12
...
...
@@ -35,6 +35,63 @@ namespace
}
void
test_prior
()
{
print_spinner
();
typedef
matrix
<
double
,
4
,
1
>
sample_type
;
typedef
linear_kernel
<
sample_type
>
kernel_type
;
std
::
vector
<
sample_type
>
samples
;
std
::
vector
<
int
>
labels
;
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
if
(
i
==
2
)
++
i
;
for
(
int
iter
=
0
;
iter
<
5
;
++
iter
)
{
sample_type
samp
;
samp
=
0
;
samp
(
i
)
=
1
;
samples
.
push_back
(
samp
);
labels
.
push_back
(
i
);
}
}
svm_multiclass_linear_trainer
<
kernel_type
,
int
>
trainer
;
multiclass_linear_decision_function
<
kernel_type
,
int
>
df
=
trainer
.
train
(
samples
,
labels
);
//cout << "test: \n" << test_multiclass_decision_function(df, samples, labels) << endl;
//cout << df.weights << endl;
//cout << df.b << endl;
std
::
vector
<
sample_type
>
samples2
;
std
::
vector
<
int
>
labels2
;
int
i
=
2
;
for
(
int
iter
=
0
;
iter
<
5
;
++
iter
)
{
sample_type
samp
;
samp
=
0
;
samp
(
i
)
=
1
;
samples2
.
push_back
(
samp
);
labels2
.
push_back
(
i
);
samples
.
push_back
(
samp
);
labels
.
push_back
(
i
);
}
trainer
.
set_prior
(
df
);
trainer
.
set_c
(
0.1
);
df
=
trainer
.
train
(
samples2
,
labels2
);
matrix
<
double
>
res
=
test_multiclass_decision_function
(
df
,
samples
,
labels
);
dlog
<<
LINFO
<<
"test:
\n
"
<<
res
;
dlog
<<
LINFO
<<
df
.
weights
;
dlog
<<
LINFO
<<
df
.
b
;
DLIB_TEST
((
unsigned
int
)
sum
(
diag
(
res
))
==
samples
.
size
());
}
template
<
typename
sample_type
>
void
run_test
()
{
...
...
@@ -99,6 +156,8 @@ namespace
run_test
<
std
::
map
<
unsigned
int
,
float
>
>
();
run_test
<
std
::
vector
<
std
::
pair
<
unsigned
int
,
float
>
>
>
();
run_test
<
std
::
vector
<
std
::
pair
<
unsigned
long
,
double
>
>
>
();
test_prior
();
}
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment