Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
0e67a4e3
Commit
0e67a4e3
authored
Nov 23, 2012
by
Davis King
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fleshed out the spec and cleaned up a few minor things.
parent
845391d0
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
99 additions
and
79 deletions
+99
-79
ranking_tools.h
dlib/svm/ranking_tools.h
+19
-22
ranking_tools_abstract.h
dlib/svm/ranking_tools_abstract.h
+80
-57
No files found.
dlib/svm/ranking_tools.h
View file @
0e67a4e3
...
...
@@ -95,11 +95,17 @@ namespace dlib
{
for
(
unsigned
long
j
=
0
;
j
<
samples
[
i
].
relevant
.
size
();
++
j
)
{
if
(
is_vector
(
samples
[
i
].
relevant
[
j
])
==
false
)
return
false
;
if
(
samples
[
i
].
relevant
[
j
].
size
()
!=
dims
)
return
false
;
}
for
(
unsigned
long
j
=
0
;
j
<
samples
[
i
].
nonrelevant
.
size
();
++
j
)
{
if
(
is_vector
(
samples
[
i
].
nonrelevant
[
j
])
==
false
)
return
false
;
if
(
samples
[
i
].
nonrelevant
[
j
].
size
()
!=
dims
)
return
false
;
}
...
...
@@ -109,6 +115,8 @@ namespace dlib
return
true
;
}
// ----------------------------------------------------------------------------------------
template
<
typename
T
>
...
...
@@ -143,18 +151,6 @@ namespace dlib
std
::
vector
<
unsigned
long
>&
x_count
,
std
::
vector
<
unsigned
long
>&
y_count
)
/*!
ensures
- This function counts how many times we see a y value greater than or equal to
an x value. This is done efficiently in O(n*log(n)) time via the use of quick
sort.
- #x_count.size() == x.size()
- #y_count.size() == y.size()
- for all valid i:
- #x_count[i] == how many times a value in y was >= x[i].
- for all valid j:
- #y_count[j] == how many times a value in x was <= y[j].
!*/
{
x_count
.
assign
(
x
.
size
(),
0
);
y_count
.
assign
(
y
.
size
(),
0
);
...
...
@@ -179,7 +175,7 @@ namespace dlib
for
(
i
=
0
,
j
=
0
;
i
<
x_count
.
size
();
++
i
)
{
// Skip past y values that are in the correct order with respect to xsort[i].
while
(
j
<
ysort
.
size
()
&&
xsort
[
i
].
first
>
ysort
[
j
].
first
)
while
(
j
<
ysort
.
size
()
&&
ysort
[
j
].
first
<
xsort
[
i
].
first
)
++
j
;
x_count
[
xsort
[
i
].
second
]
=
ysort
.
size
()
-
j
;
...
...
@@ -190,7 +186,7 @@ namespace dlib
for
(
i
=
0
,
j
=
0
;
j
<
y_count
.
size
();
++
j
)
{
// Skip past x values that are in the incorrect order with respect to ysort[j].
while
(
i
<
xsort
.
size
()
&&
xsort
[
i
].
first
<=
ysort
[
j
].
first
)
while
(
i
<
xsort
.
size
()
&&
!
(
ysort
[
j
].
first
<
xsort
[
i
].
first
)
)
++
i
;
y_count
[
ysort
[
j
].
second
]
=
i
;
...
...
@@ -207,11 +203,15 @@ namespace dlib
const
ranking_function
&
funct
,
const
std
::
vector
<
ranking_pair
<
T
>
>&
samples
)
/*!
ensures
- returns the fraction of ranking pairs predicted correctly.
!*/
{
// make sure requires clause is not broken
DLIB_ASSERT
(
is_ranking_problem
(
samples
),
"
\t
double test_ranking_function()"
<<
"
\n\t
invalid inputs were given to this function"
<<
"
\n\t
samples.size(): "
<<
samples
.
size
()
<<
"
\n\t
is_ranking_problem(samples): "
<<
is_ranking_problem
(
samples
)
);
unsigned
long
total_pairs
=
0
;
unsigned
long
total_wrong
=
0
;
...
...
@@ -238,9 +238,6 @@ namespace dlib
// Note that we don't need to look at nonrel_counts since it is redundant with
// the information in rel_counts in this case.
total_wrong
+=
sum
(
vector_to_matrix
(
rel_counts
));
// TODO, remove
DLIB_CASSERT
(
sum
(
vector_to_matrix
(
rel_counts
))
==
sum
(
vector_to_matrix
(
nonrel_counts
)),
""
);
}
return
static_cast
<
double
>
(
total_pairs
-
total_wrong
)
/
total_pairs
;
...
...
@@ -259,7 +256,7 @@ namespace dlib
)
{
// make sure requires clause is not broken
DLIB_
C
ASSERT
(
is_ranking_problem
(
samples
)
&&
DLIB_ASSERT
(
is_ranking_problem
(
samples
)
&&
1
<
folds
&&
folds
<=
static_cast
<
long
>
(
samples
.
size
()),
"
\t
double cross_validate_ranking_trainer()"
<<
"
\n\t
invalid inputs were given to this function"
...
...
dlib/svm/ranking_tools_abstract.h
View file @
0e67a4e3
...
...
@@ -20,16 +20,29 @@ namespace dlib
{
/*!
WHAT THIS OBJECT REPRESENTS
This object is used to contain a ranking example. In particular, we say
that a good ranking of T objects is one in which all the elements in
this->relevant are ranked higher than the elements of this->nonrelevant.
Therefore, ranking_pair objects are used to represent training examples for
learning-to-rank tasks.
!*/
ranking_pair
()
{}
/*!
ensures
- #relevant.size() == 0
- #nonrelevant.size() == 0
!*/
ranking_pair
(
const
std
::
vector
<
T
>&
r
,
const
std
::
vector
<
T
>&
nr
)
:
relevant
(
r
),
nonrelevant
(
nr
)
{}
)
:
relevant
(
r
),
nonrelevant
(
nr
)
{}
/*!
ensures
- #relevant == r
- #nonrelevant == nr
!*/
std
::
vector
<
T
>
relevant
;
std
::
vector
<
T
>
nonrelevant
;
...
...
@@ -64,70 +77,60 @@ namespace dlib
>
bool
is_ranking_problem
(
const
std
::
vector
<
ranking_pair
<
T
>
>&
samples
)
{
if
(
samples
.
size
()
==
0
)
return
false
;
for
(
unsigned
long
i
=
0
;
i
<
samples
.
size
();
++
i
)
{
if
(
samples
[
i
].
relevant
.
size
()
==
0
)
return
false
;
if
(
samples
[
i
].
nonrelevant
.
size
()
==
0
)
return
false
;
}
// If these are dense vectors then they must all have the same dimensionality.
if
(
is_matrix
<
T
>::
value
)
{
const
long
dims
=
max_index_plus_one
(
samples
[
0
].
relevant
);
for
(
unsigned
long
i
=
0
;
i
<
samples
.
size
();
++
i
)
{
for
(
unsigned
long
j
=
0
;
j
<
samples
[
i
].
relevant
.
size
();
++
j
)
{
if
(
samples
[
i
].
relevant
[
j
].
size
()
!=
dims
)
return
false
;
}
for
(
unsigned
long
j
=
0
;
j
<
samples
[
i
].
nonrelevant
.
size
();
++
j
)
{
if
(
samples
[
i
].
nonrelevant
[
j
].
size
()
!=
dims
)
return
false
;
}
}
}
return
true
;
}
);
/*!
ensures
- returns true if the data in samples represents a valid learning-to-rank
learning problem. That is, this function returns true if all of the
following are true and false otherwise:
- samples.size() > 0
- for all valid i:
- samples[i].relevant.size() > 0
- samples[i].nonrelevant.size() > 0
- if (is_matrix<T>::value == true) then
- All the elements of samples::nonrelevant and samples::relevant must
represent row or column vectors and they must be the same dimension.
!*/
// ----------------------------------------------------------------------------------------
template
<
typename
T
>
unsigned
long
max_index_plus_one
(
const
ranking_pair
<
T
>&
item
)
{
return
std
::
max
(
max_index_plus_one
(
item
.
relevant
),
max_index_plus_one
(
item
.
nonrelevant
));
}
);
/*!
requires
- T must be a dlib::matrix capable of storing column vectors or T must be a
sparse vector type as defined in dlib/svm/sparse_vector_abstract.h.
ensures
- returns std::max(max_index_plus_one(item.relevant), max_index_plus_one(item.nonrelevant)).
Therefore, this function can be used to find the dimensionality of the
vectors stored in item.
!*/
template
<
typename
T
>
unsigned
long
max_index_plus_one
(
const
std
::
vector
<
ranking_pair
<
T
>
>&
samples
)
{
unsigned
long
dims
=
0
;
for
(
unsigned
long
i
=
0
;
i
<
samples
.
size
();
++
i
)
{
dims
=
std
::
max
(
dims
,
max_index_plus_one
(
samples
[
i
]));
}
return
dims
;
}
);
/*!
requires
- T must be a dlib::matrix capable of storing column vectors or T must be a
sparse vector type as defined in dlib/svm/sparse_vector_abstract.h.
ensures
- returns the maximum of max_index_plus_one(samples[i]) over all valid values
of i. Therefore, this function can be used to find the dimensionality of the
vectors stored in samples.
!*/
// ----------------------------------------------------------------------------------------
template
<
typename
T
>
template
<
typename
T
>
void
count_ranking_inversions
(
const
std
::
vector
<
T
>&
x
,
const
std
::
vector
<
T
>&
y
,
...
...
@@ -135,10 +138,13 @@ namespace dlib
std
::
vector
<
unsigned
long
>&
y_count
);
/*!
requires
- T objects must be copyable
- T objects must be comparable via operator<
ensures
- This function counts how many times we see a y value greater than or equal to
an x value. This is done efficiently in O(n*log(n)) time via the use of
quick sort.
- #x_count.size() == x.size()
- #y_count.size() == y.size()
- for all valid i:
...
...
@@ -158,9 +164,18 @@ namespace dlib
const
std
::
vector
<
ranking_pair
<
T
>
>&
samples
);
/*!
requires
- is_ranking_problem(samples) == true
- ranking_function == some kind of decision function object (e.g. decision_function)
ensures
- returns the fraction of ranking pairs predicted correctly.
- TODO
- Tests the given ranking function on the supplied example ranking data and
returns the fraction of ranking pair orderings predicted correctly. This is
a number in the range [0,1] where 0 means everything was incorrectly
predicted while 1 means everything was correctly predicted.
- In particular, this function returns the fraction of times that the following
is true:
- funct(samples[k].relevant[i]) > funct(samples[k].nonrelevant[j])
(for all valid i,j,k)
!*/
// ----------------------------------------------------------------------------------------
...
...
@@ -178,8 +193,16 @@ namespace dlib
requires
- is_ranking_problem(samples) == true
- 1 < folds <= samples.size()
- trainer_type == some kind of ranking trainer object (e.g. svm_rank_trainer)
ensures
- TODO
- Performs k-fold cross validation by using the given trainer to solve the
given ranking problem for the given number of folds. Each fold is tested
using the output of the trainer and the average ranking accuracy from all
folds is returned.
- The accuracy is computed the same way test_ranking_function() computes its
accuracy. Therefore, it is a number in the range [0,1] that represents the
fraction of times a ranking pair's ordering was predicted correctly.
- The number of folds used is given by the folds argument.
!*/
// ----------------------------------------------------------------------------------------
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment