Commit bbefbc17 authored by Davis King's avatar Davis King

Fixed the ranking test functions so they correctly compute the MAP values

for ranking functions which output constant values.
parent 0c0e744d
......@@ -216,6 +216,24 @@ namespace dlib
// ----------------------------------------------------------------------------------------
namespace impl
{
inline bool compare_first_reverse_second (
const std::pair<double,bool>& a,
const std::pair<double,bool>& b
)
{
if (a.first < b.first)
return true;
else if (a.first > b.first)
return false;
else if (a.second == true)
return true;
else
return false;
}
}
template <
typename ranking_function,
typename T
......@@ -264,8 +282,11 @@ namespace dlib
}
// Now compute the average precision for this sample. We need to sort the
// results and the back them into total_ranking.
std::sort(total_scores.rbegin(), total_scores.rend());
// results and the back them into total_ranking. Note that we sort them so
// that, if you get a block of ranking values that are all equal, the elements
// marked as true will come last. This prevents a ranking from outputting a
// constant value for everything and still getting a good MAP score.
std::sort(total_scores.rbegin(), total_scores.rend(), impl::compare_first_reverse_second);
total_ranking.clear();
for (unsigned long i = 0; i < total_scores.size(); ++i)
total_ranking.push_back(total_scores[i].second);
......@@ -390,8 +411,11 @@ namespace dlib
}
// Now compute the average precision for this sample. We need to sort the
// results and the back them into total_ranking.
std::sort(total_scores.rbegin(), total_scores.rend());
// results and the back them into total_ranking. Note that we sort them so
// that, if you get a block of ranking values that are all equal, the elements
// marked as true will come last. This prevents a ranking from outputting a
// constant value for everything and still getting a good MAP score.
std::sort(total_scores.rbegin(), total_scores.rend(), impl::compare_first_reverse_second);
total_ranking.clear();
for (unsigned long i = 0; i < total_scores.size(); ++i)
total_ranking.push_back(total_scores[i].second);
......
......@@ -181,7 +181,8 @@ namespace dlib
- M(1) == the mean average precision of the rankings induced by funct.
(Mean average precision is a number in the range 0 to 1. Moreover, a
mean average precision of 1 means everything was correctly predicted
while smaller values indicate worse rankings.)
while smaller values indicate worse rankings. See the documentation
for average_precision() for details of its computation.)
!*/
// ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment