Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
754610f2
Commit
754610f2
authored
May 05, 2014
by
Davis King
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added the option to set a prior to svm_rank_trainer.
parent
461abe65
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
129 additions
and
1 deletion
+129
-1
svm_rank_trainer.h
dlib/svm/svm_rank_trainer.h
+51
-1
svm_rank_trainer_abstract.h
dlib/svm/svm_rank_trainer_abstract.h
+43
-0
ranking.cpp
dlib/test/ranking.cpp
+35
-0
No files found.
dlib/svm/svm_rank_trainer.h
View file @
754610f2
...
@@ -297,6 +297,8 @@ namespace dlib
...
@@ -297,6 +297,8 @@ namespace dlib
)
)
{
{
last_weight_1
=
should_last_weight_be_1
;
last_weight_1
=
should_last_weight_be_1
;
if
(
last_weight_1
)
prior
.
set_size
(
0
);
}
}
void
set_oca
(
void
set_oca
(
...
@@ -326,6 +328,33 @@ namespace dlib
...
@@ -326,6 +328,33 @@ namespace dlib
)
)
{
{
learn_nonnegative_weights
=
value
;
learn_nonnegative_weights
=
value
;
if
(
learn_nonnegative_weights
)
prior
.
set_size
(
0
);
}
void
set_prior
(
const
trained_function_type
&
prior_
)
{
// make sure requires clause is not broken
DLIB_ASSERT
(
prior_
.
basis_vectors
.
size
()
==
1
&&
prior_
.
alpha
(
0
)
==
1
,
"
\t
void svm_rank_trainer::set_prior()"
<<
"
\n\t
The supplied prior could not have been created by this object's train() method."
<<
"
\n\t
prior_.basis_vectors.size(): "
<<
prior_
.
basis_vectors
.
size
()
<<
"
\n\t
prior_.alpha(0): "
<<
prior_
.
alpha
(
0
)
<<
"
\n\t
this: "
<<
this
);
prior
=
prior_
.
basis_vectors
(
0
);
learn_nonnegative_weights
=
false
;
last_weight_1
=
false
;
}
bool
has_prior
(
)
const
{
return
prior
.
size
()
!=
0
;
}
}
void
set_c
(
void
set_c
(
...
@@ -379,10 +408,30 @@ namespace dlib
...
@@ -379,10 +408,30 @@ namespace dlib
force_weight_1_idx
=
num_dims
-
1
;
force_weight_1_idx
=
num_dims
-
1
;
}
}
solver
(
make_oca_problem_ranking_svm
<
w_type
>
(
C
,
samples
,
verbose
,
eps
,
max_iterations
),
if
(
has_prior
())
{
if
(
is_matrix
<
sample_type
>::
value
)
{
// make sure requires clause is not broken
DLIB_CASSERT
(
num_dims
==
(
unsigned
long
)
prior
.
size
(),
"
\t
decision_function svm_rank_trainer::train(samples)"
<<
"
\n\t
The dimension of the training vectors must match the dimension of
\n
"
<<
"
\n\t
those used to create the prior."
<<
"
\n\t
num_dims: "
<<
num_dims
<<
"
\n\t
prior.size(): "
<<
prior
.
size
()
);
}
solver
(
make_oca_problem_ranking_svm
<
w_type
>
(
C
,
samples
,
verbose
,
eps
,
max_iterations
),
w
,
prior
);
}
else
{
solver
(
make_oca_problem_ranking_svm
<
w_type
>
(
C
,
samples
,
verbose
,
eps
,
max_iterations
),
w
,
w
,
num_nonnegative
,
num_nonnegative
,
force_weight_1_idx
);
force_weight_1_idx
);
}
// put the solution into a decision function and then return it
// put the solution into a decision function and then return it
...
@@ -415,6 +464,7 @@ namespace dlib
...
@@ -415,6 +464,7 @@ namespace dlib
unsigned
long
max_iterations
;
unsigned
long
max_iterations
;
bool
learn_nonnegative_weights
;
bool
learn_nonnegative_weights
;
bool
last_weight_1
;
bool
last_weight_1
;
matrix
<
scalar_type
,
0
,
1
>
prior
;
};
};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
...
...
dlib/svm/svm_rank_trainer_abstract.h
View file @
754610f2
...
@@ -58,6 +58,7 @@ namespace dlib
...
@@ -58,6 +58,7 @@ namespace dlib
- #get_max_iterations() == 10000
- #get_max_iterations() == 10000
- #learns_nonnegative_weights() == false
- #learns_nonnegative_weights() == false
- #forces_last_weight_to_1() == false
- #forces_last_weight_to_1() == false
- #has_prior() == false
!*/
!*/
explicit
svm_rank_trainer
(
explicit
svm_rank_trainer
(
...
@@ -76,6 +77,7 @@ namespace dlib
...
@@ -76,6 +77,7 @@ namespace dlib
- #get_max_iterations() == 10000
- #get_max_iterations() == 10000
- #learns_nonnegative_weights() == false
- #learns_nonnegative_weights() == false
- #forces_last_weight_to_1() == false
- #forces_last_weight_to_1() == false
- #has_prior() == false
!*/
!*/
void
set_epsilon
(
void
set_epsilon
(
...
@@ -146,6 +148,8 @@ namespace dlib
...
@@ -146,6 +148,8 @@ namespace dlib
/*!
/*!
ensures
ensures
- #forces_last_weight_to_1() == should_last_weight_be_1
- #forces_last_weight_to_1() == should_last_weight_be_1
- if (should_last_weight_be_1 == true) then
- #has_prior() == false
!*/
!*/
void
set_oca
(
void
set_oca
(
...
@@ -190,6 +194,39 @@ namespace dlib
...
@@ -190,6 +194,39 @@ namespace dlib
/*!
/*!
ensures
ensures
- #learns_nonnegative_weights() == value
- #learns_nonnegative_weights() == value
- if (value == true) then
- #has_prior() == false
!*/
void
set_prior
(
const
trained_function_type
&
prior
);
/*!
requires
- prior == a function produced by a call to this class's train() function.
Therefore, it must be the case that:
- prior.basis_vectors.size() == 1
- prior.alpha(0) == 1
ensures
- Subsequent calls to train() will try to learn a function similar to the
given prior.
- #has_prior() == true
- #learns_nonnegative_weights() == false
- #forces_last_weight_to_1() == false
!*/
bool
has_prior
(
)
const
/*!
ensures
- returns true if a prior has been set and false otherwise. Having a prior
set means that you have called set_prior() and supplied a previously
trained function as a reference. In this case, any call to train() will
try to learn a function that matches the behavior of the prior as close
as possible but also fits the supplied training data. In more technical
detail, having a prior means we replace the ||w||^2 regularizer with one
of the form ||w-prior||^2 where w is the set of parameters for a learned
function.
!*/
!*/
void
set_c
(
void
set_c
(
...
@@ -219,6 +256,9 @@ namespace dlib
...
@@ -219,6 +256,9 @@ namespace dlib
/*!
/*!
requires
requires
- is_ranking_problem(samples) == true
- is_ranking_problem(samples) == true
- if (has_prior()) then
- The vectors in samples must have the same dimensionality as the
vectors used to train the prior given to set_prior().
ensures
ensures
- trains a ranking support vector classifier given the training samples.
- trains a ranking support vector classifier given the training samples.
- returns a decision function F with the following properties:
- returns a decision function F with the following properties:
...
@@ -237,6 +277,9 @@ namespace dlib
...
@@ -237,6 +277,9 @@ namespace dlib
/*!
/*!
requires
requires
- is_ranking_problem(std::vector<ranking_pair<sample_type> >(1, sample)) == true
- is_ranking_problem(std::vector<ranking_pair<sample_type> >(1, sample)) == true
- if (has_prior()) then
- The vectors in samples must have the same dimensionality as the
vectors used to train the prior given to set_prior().
ensures
ensures
- This is just a convenience routine for calling the above train()
- This is just a convenience routine for calling the above train()
function. That is, it just copies sample into a std::vector object and
function. That is, it just copies sample into a std::vector object and
...
...
dlib/test/ranking.cpp
View file @
754610f2
...
@@ -73,6 +73,40 @@ namespace
...
@@ -73,6 +73,40 @@ namespace
}
}
}
}
// ----------------------------------------------------------------------------------------
void
run_prior_test
()
{
print_spinner
();
typedef
matrix
<
double
,
3
,
1
>
sample_type
;
typedef
linear_kernel
<
sample_type
>
kernel_type
;
svm_rank_trainer
<
kernel_type
>
trainer
;
ranking_pair
<
sample_type
>
data
;
sample_type
samp
;
samp
=
0
,
0
,
1
;
data
.
relevant
.
push_back
(
samp
);
samp
=
0
,
1
,
0
;
data
.
nonrelevant
.
push_back
(
samp
);
trainer
.
set_c
(
10
);
decision_function
<
kernel_type
>
df
=
trainer
.
train
(
data
);
trainer
.
set_prior
(
df
);
data
.
relevant
.
clear
();
data
.
nonrelevant
.
clear
();
samp
=
1
,
0
,
0
;
data
.
relevant
.
push_back
(
samp
);
samp
=
0
,
1
,
0
;
data
.
nonrelevant
.
push_back
(
samp
);
df
=
trainer
.
train
(
data
);
dlog
<<
LINFO
<<
trans
(
df
.
basis_vectors
(
0
));
DLIB_TEST
(
df
.
basis_vectors
(
0
)(
0
)
>
0
);
DLIB_TEST
(
df
.
basis_vectors
(
0
)(
1
)
<
0
);
DLIB_TEST
(
df
.
basis_vectors
(
0
)(
2
)
>
0
);
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
void
dotest1
()
void
dotest1
()
...
@@ -355,6 +389,7 @@ namespace
...
@@ -355,6 +389,7 @@ namespace
dotest_sparse_vectors
();
dotest_sparse_vectors
();
test_svmrank_weight_force_dense
<
true
>
();
test_svmrank_weight_force_dense
<
true
>
();
test_svmrank_weight_force_dense
<
false
>
();
test_svmrank_weight_force_dense
<
false
>
();
run_prior_test
();
}
}
}
a
;
}
a
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment