Commit 97a55037 authored by Davis King

Changed ranking evaluation functions to return the mean average precision

in addition to just raw ranking accuracy.  This changes their return types
from double to matrix<double,1,2>.
parent 612fe85b
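Since the return type changes from double to matrix<double,1,2>, existing call sites need a small adjustment. A minimal before/after sketch at a caller (df and samples stand for any trained ranking function and its evaluation data, as in the tests below):

    // before this commit:
    //   double acc = test_ranking_function(df, samples);
    // after this commit:
    matrix<double,1,2> res = test_ranking_function(df, samples);
    const double ranking_accuracy       = res(0);  // fraction of correctly ordered pairs
    const double mean_average_precision = res(1);  // MAP over the ranking_pair samples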
@@ -11,6 +11,7 @@
#include <utility>
#include <algorithm>
#include "sparse_vector.h"
+#include "../statistics.h"

namespace dlib
{
@@ -219,7 +220,7 @@ namespace dlib
    template <
        typename ranking_function,
        typename T
        >
-    double test_ranking_function (
+    matrix<double,1,2> test_ranking_function (
        const ranking_function& funct,
        const std::vector<ranking_pair<T> >& samples
    )
@@ -240,16 +241,36 @@ namespace dlib
        std::vector<unsigned long> rel_counts;
        std::vector<unsigned long> nonrel_counts;
+        running_stats<double> rs;
+        std::vector<std::pair<double,bool> > total_scores;
+        std::vector<bool> total_ranking;
        for (unsigned long i = 0; i < samples.size(); ++i)
        {
            rel_scores.resize(samples[i].relevant.size());
            nonrel_scores.resize(samples[i].nonrelevant.size());
+            total_scores.clear();
            for (unsigned long k = 0; k < rel_scores.size(); ++k)
+            {
                rel_scores[k] = funct(samples[i].relevant[k]);
+                total_scores.push_back(std::make_pair(rel_scores[k], true));
+            }
            for (unsigned long k = 0; k < nonrel_scores.size(); ++k)
+            {
                nonrel_scores[k] = funct(samples[i].nonrelevant[k]);
+                total_scores.push_back(std::make_pair(nonrel_scores[k], false));
+            }
+            // Now compute the average precision for this sample.  We need to sort the
+            // results and then pack them into total_ranking.
+            std::sort(total_scores.rbegin(), total_scores.rend());
+            total_ranking.clear();
+            for (unsigned long i = 0; i < total_scores.size(); ++i)
+                total_ranking.push_back(total_scores[i].second);
+            rs.add(average_precision(total_ranking));
            count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts);
@@ -260,7 +281,11 @@ namespace dlib
            total_wrong += sum(mat(rel_counts));
        }
-        return static_cast<double>(total_pairs - total_wrong) / total_pairs;
+        const double rank_swaps = static_cast<double>(total_pairs - total_wrong) / total_pairs;
+        const double mean_average_precision = rs.mean();
+        matrix<double,1,2> res;
+        res = rank_swaps, mean_average_precision;
+        return res;
    }
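average_precision() is pulled in through the newly added statistics.h include. As a point of reference, here is a minimal sketch (not dlib's implementation) of the quantity rs accumulates above, assuming the conventional definition of average precision; toy_average_precision is a hypothetical name used only for illustration:

    #include <vector>

    // Average precision of a ranking, where ranking[k] == true means the item
    // at position k is relevant.  For each relevant item we take the precision
    // of the list cut off just after that item, then average those precisions.
    double toy_average_precision (const std::vector<bool>& ranking)
    {
        double relevant_seen = 0;
        double precision_sum = 0;
        for (unsigned long k = 0; k < ranking.size(); ++k)
        {
            if (ranking[k])
            {
                ++relevant_seen;
                precision_sum += relevant_seen/(k+1);  // precision at cutoff k+1
            }
        }
        // Returning 1 when there are no relevant items is a convention of this
        // sketch; dlib's behavior for that edge case is not shown in this diff.
        return (relevant_seen != 0) ? precision_sum/relevant_seen : 1;
    }
    // E.g. the ranking {true, false, true} gives (1/1 + 2/3)/2 = 5/6, while any
    // ranking that puts all relevant items first gives exactly 1.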
// ----------------------------------------------------------------------------------------
@@ -269,7 +294,7 @@ namespace dlib
    template <
        typename ranking_function,
        typename T
        >
-    double test_ranking_function (
+    matrix<double,1,2> test_ranking_function (
        const ranking_function& funct,
        const ranking_pair<T>& sample
    )
@@ -283,7 +308,7 @@ namespace dlib
    template <
        typename trainer_type,
        typename T
        >
-    double cross_validate_ranking_trainer (
+    matrix<double,1,2> cross_validate_ranking_trainer (
        const trainer_type& trainer,
        const std::vector<ranking_pair<T> >& samples,
        const long folds
@@ -317,6 +342,9 @@ namespace dlib
        std::vector<unsigned long> rel_counts;
        std::vector<unsigned long> nonrel_counts;
+        running_stats<double> rs;
+        std::vector<std::pair<double,bool> > total_scores;
+        std::vector<bool> total_ranking;
        for (long i = 0; i < folds; ++i)
        {
@@ -347,11 +375,28 @@ namespace dlib
                rel_scores.resize(samples_test[i].relevant.size());
                nonrel_scores.resize(samples_test[i].nonrelevant.size());
+                total_scores.clear();
                for (unsigned long k = 0; k < rel_scores.size(); ++k)
+                {
                    rel_scores[k] = df(samples_test[i].relevant[k]);
+                    total_scores.push_back(std::make_pair(rel_scores[k], true));
+                }
                for (unsigned long k = 0; k < nonrel_scores.size(); ++k)
+                {
                    nonrel_scores[k] = df(samples_test[i].nonrelevant[k]);
+                    total_scores.push_back(std::make_pair(nonrel_scores[k], false));
+                }
+                // Now compute the average precision for this sample.  We need to sort the
+                // results and then pack them into total_ranking.
+                std::sort(total_scores.rbegin(), total_scores.rend());
+                total_ranking.clear();
+                for (unsigned long i = 0; i < total_scores.size(); ++i)
+                    total_ranking.push_back(total_scores[i].second);
+                rs.add(average_precision(total_ranking));
                count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts);
@@ -364,7 +409,11 @@ namespace dlib
        } // for (long i = 0; i < folds; ++i)
-        return static_cast<double>(total_pairs - total_wrong) / total_pairs;
+        const double rank_swaps = static_cast<double>(total_pairs - total_wrong) / total_pairs;
+        const double mean_average_precision = rs.mean();
+        matrix<double,1,2> res;
+        res = rank_swaps, mean_average_precision;
+        return res;
    }
// ----------------------------------------------------------------------------------------
@@ -159,7 +159,7 @@ namespace dlib
    template <
        typename ranking_function,
        typename T
        >
-    double test_ranking_function (
+    matrix<double,1,2> test_ranking_function (
        const ranking_function& funct,
        const std::vector<ranking_pair<T> >& samples
    );
@@ -171,11 +171,17 @@ namespace dlib
            - Tests the given ranking function on the supplied example ranking data and
              returns the fraction of ranking pair orderings predicted correctly.  This is
              a number in the range [0,1] where 0 means everything was incorrectly
-             predicted while 1 means everything was correctly predicted.
-           - In particular, this function returns the fraction of times that the following
-             is true:
-               - funct(samples[k].relevant[i]) > funct(samples[k].nonrelevant[j])
-                 (for all valid i,j,k)
+             predicted while 1 means everything was correctly predicted.  This function
+             also returns the mean average precision.
+           - In particular, this function returns a matrix M summarizing the results.
+             Specifically, it returns an M such that:
+               - M(0) == the fraction of times that the following is true:
+                   - funct(samples[k].relevant[i]) > funct(samples[k].nonrelevant[j])
+                     (for all valid i,j,k)
+               - M(1) == the mean average precision of the rankings induced by funct.
+                 (Mean average precision is a number in the range 0 to 1.  Moreover, a
+                 mean average precision of 1 means everything was correctly predicted
+                 while smaller values indicate worse rankings.)
    !*/
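A short usage sketch of the updated interface follows. The kernel, sample type, and toy data are illustrative assumptions, not part of the spec; with this trivially separable data both entries should come out as 1:

    #include <dlib/svm.h>
    #include <iostream>
    #include <vector>
    using namespace dlib;

    int main()
    {
        typedef matrix<double,4,1> sample_type;
        typedef linear_kernel<sample_type> kernel_type;

        // One relevant and one nonrelevant example per ranking pair.
        ranking_pair<sample_type> p;
        sample_type rel, nonrel;
        rel    = 1, 0, 0, 0;
        nonrel = 0, 1, 0, 0;
        p.relevant.push_back(rel);
        p.nonrelevant.push_back(nonrel);
        std::vector<ranking_pair<sample_type> > samples(4, p);

        svm_rank_trainer<kernel_type> trainer;
        decision_function<kernel_type> df = trainer.train(samples);

        const matrix<double,1,2> res = test_ranking_function(df, samples);
        std::cout << "pair ordering accuracy: " << res(0) << std::endl;
        std::cout << "mean average precision: " << res(1) << std::endl;
    }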
// ----------------------------------------------------------------------------------------
@@ -184,7 +190,7 @@ namespace dlib
    template <
        typename ranking_function,
        typename T
        >
-    double test_ranking_function (
+    matrix<double,1,2> test_ranking_function (
        const ranking_function& funct,
        const ranking_pair<T>& sample
    );
@@ -206,7 +212,7 @@ namespace dlib
    template <
        typename trainer_type,
        typename T
        >
-    double cross_validate_ranking_trainer (
+    matrix<double,1,2> cross_validate_ranking_trainer (
        const trainer_type& trainer,
        const std::vector<ranking_pair<T> >& samples,
        const long folds
@@ -219,11 +225,15 @@ namespace dlib
        ensures
            - Performs k-fold cross validation by using the given trainer to solve the
              given ranking problem for the given number of folds.  Each fold is tested
-             using the output of the trainer and the average ranking accuracy from all
-             folds is returned.
+             using the output of the trainer, and the average ranking accuracy and the
+             mean average precision over all folds are returned.
            - The accuracy is computed the same way test_ranking_function() computes its
              accuracy.  Therefore, it is a number in the range [0,1] that represents the
-             fraction of times a ranking pair's ordering was predicted correctly.
+             fraction of times a ranking pair's ordering was predicted correctly.
+             Similarly, the mean average precision is computed identically to
+             test_ranking_function().  In particular, this means that this function
+             returns a matrix M such that:
+               - M(0) == the ranking accuracy
+               - M(1) == the mean average precision
            - The number of folds used is given by the folds argument.
    !*/
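A matching sketch for the cross-validation routine, reusing the hypothetical trainer and samples from the sketch above; since every fold sees the same separable data, both entries should again come out as 1:

    const matrix<double,1,2> cv = cross_validate_ranking_trainer(trainer, samples, 4);
    std::cout << "cv ranking accuracy:       " << cv(0) << std::endl;
    std::cout << "cv mean average precision: " << cv(1) << std::endl;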
@@ -108,9 +108,11 @@ namespace
        decision_function<kernel_type> df = trainer.train(samples);
        dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples);
-        DLIB_TEST(std::abs(test_ranking_function(df, samples) - 1.0) < 1e-14);
+        matrix<double,1,2> res;
+        res = 1,1;
+        DLIB_TEST(equal(test_ranking_function(df, samples), res));
-        DLIB_TEST(std::abs(test_ranking_function(trainer.train(samples[1]), samples) - 1.0) < 1e-14);
+        DLIB_TEST(equal(test_ranking_function(trainer.train(samples[1]), samples), res));
        trainer.set_epsilon(1e-13);
        df = trainer.train(samples);
@@ -121,10 +123,10 @@ namespace
        DLIB_TEST(length(truew - df.basis_vectors(0)) < 1e-10);
        dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples);
-        DLIB_TEST(std::abs(test_ranking_function(df, samples) - 1.0) < 1e-14);
+        DLIB_TEST(equal(test_ranking_function(df, samples), res));
        dlog << LINFO << "cv-accuracy: "<< cross_validate_ranking_trainer(trainer, samples,2);
-        DLIB_TEST(std::abs(cross_validate_ranking_trainer(trainer, samples,2) - 0.7777777778) < 0.0001);
+        DLIB_TEST(std::abs(cross_validate_ranking_trainer(trainer, samples,2)(0) - 0.7777777778) < 0.0001);
        trainer.set_learns_nonnegative_weights(true);
        df = trainer.train(samples);
@@ -132,7 +134,7 @@ namespace
        dlog << LINFO << df.basis_vectors(0);
        DLIB_TEST(length(truew - df.basis_vectors(0)) < 1e-10);
        dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples);
-        DLIB_TEST(std::abs(test_ranking_function(df, samples) - 1.0) < 1e-14);
+        DLIB_TEST(equal(test_ranking_function(df, samples), res));
        samples.clear();
@@ -141,7 +143,7 @@ namespace
        samples.push_back(p);
        samples.push_back(p);
        dlog << LINFO << "cv-accuracy: "<< cross_validate_ranking_trainer(trainer, samples,4);
-        DLIB_TEST(std::abs(cross_validate_ranking_trainer(trainer, samples,4) - 1) < 1e-12);
+        DLIB_TEST(equal(cross_validate_ranking_trainer(trainer, samples,4), res));
    }
// ----------------------------------------------------------------------------------------
@@ -178,10 +180,13 @@ namespace
        decision_function<kernel_type> df = trainer.train(samples);
+        matrix<double,1,2> res;
+        res = 1,1;
        dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples);
-        DLIB_TEST(std::abs(test_ranking_function(df, samples) - 1.0) < 1e-14);
+        DLIB_TEST(equal(test_ranking_function(df, samples), res));
-        DLIB_TEST(std::abs(test_ranking_function(trainer.train(samples[1]), samples) - 1.0) < 1e-14);
+        DLIB_TEST(equal(test_ranking_function(trainer.train(samples[1]), samples), res));
        trainer.set_epsilon(1e-13);
        df = trainer.train(samples);
@@ -195,10 +200,10 @@ namespace
        DLIB_TEST(length(subtract(truew , df.basis_vectors(0))) < 1e-10);
        dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples);
-        DLIB_TEST(std::abs(test_ranking_function(df, samples) - 1.0) < 1e-14);
+        DLIB_TEST(equal(test_ranking_function(df, samples), res));
        dlog << LINFO << "cv-accuracy: "<< cross_validate_ranking_trainer(trainer, samples,2);
-        DLIB_TEST(std::abs(cross_validate_ranking_trainer(trainer, samples,2) - 0.7777777778) < 0.0001);
+        DLIB_TEST(std::abs(cross_validate_ranking_trainer(trainer, samples,2)(0) - 0.7777777778) < 0.0001);
        trainer.set_learns_nonnegative_weights(true);
        df = trainer.train(samples);
@@ -209,7 +214,7 @@ namespace
        dlog << LINFO << sparse_to_dense(df.basis_vectors(0));
        DLIB_TEST(length(subtract(truew , df.basis_vectors(0))) < 1e-10);
        dlog << LINFO << "accuracy: "<< test_ranking_function(df, samples);
-        DLIB_TEST(std::abs(test_ranking_function(df, samples) - 1.0) < 1e-14);
+        DLIB_TEST(equal(test_ranking_function(df, samples), res));
        samples.clear();
@@ -218,7 +223,7 @@ namespace
        samples.push_back(p);
        samples.push_back(p);
        dlog << LINFO << "cv-accuracy: "<< cross_validate_ranking_trainer(trainer, samples,4);
-        DLIB_TEST(std::abs(cross_validate_ranking_trainer(trainer, samples,4) - 1) < 1e-12);
+        DLIB_TEST(equal(cross_validate_ranking_trainer(trainer, samples,4), res));
    }
// ----------------------------------------------------------------------------------------
@@ -303,18 +308,20 @@ namespace
        decision_function<kernel_type> df;
        df = trainer.train(pair);
+        matrix<double,1,2> res;
+        res = 1,1;
        dlog << LINFO << "weights: "<< trans(df.basis_vectors(0));
-        const double acc1 = test_ranking_function(df, pair);
+        const matrix<double,1,2> acc1 = test_ranking_function(df, pair);
        dlog << LINFO << "ranking accuracy: " << acc1;
-        DLIB_TEST(std::abs(acc1 - 1) == 0);
+        DLIB_TEST(equal(acc1,res));
        simple_rank_trainer<kernel_type,use_dcd_trainer> strainer;
        decision_function<kernel_type> df2;
        df2 = strainer.train(pair);
        dlog << LINFO << "weights: "<< trans(df2.basis_vectors(0));
-        const double acc2 = test_ranking_function(df2, pair);
+        const matrix<double,1,2> acc2 = test_ranking_function(df2, pair);
        dlog << LINFO << "ranking accuracy: " << acc2;
-        DLIB_TEST(std::abs(acc2 - 1) == 0);
+        DLIB_TEST(equal(acc2,res));
        dlog << LINFO << "w error: " << max(abs(df.basis_vectors(0) - df2.basis_vectors(0)));
        dlog << LINFO << "b error: " << abs(df.b - df2.b);