Commit c9ed8aa6 authored by Davis King's avatar Davis King

Changed object detector testing functions to output average precision

instead of mean average precision.
parent 3bfc3612
...@@ -20,9 +20,10 @@ namespace dlib ...@@ -20,9 +20,10 @@ namespace dlib
{ {
inline unsigned long number_of_truth_hits ( inline unsigned long number_of_truth_hits (
const std::vector<full_object_detection>& truth_boxes, const std::vector<full_object_detection>& truth_boxes,
const std::vector<rectangle>& boxes, const std::vector<std::pair<double,rectangle> >& boxes,
const double overlap_eps, const double overlap_eps,
double& ap std::vector<std::pair<double,bool> >& all_dets,
unsigned long& missing_detections
) )
/*! /*!
requires requires
...@@ -34,17 +35,14 @@ namespace dlib ...@@ -34,17 +35,14 @@ namespace dlib
A.intersect(B).area()/(A+B).area() A.intersect(B).area()/(A+B).area()
- No element of boxes is allowed to account for more than one element of truth_boxes. - No element of boxes is allowed to account for more than one element of truth_boxes.
- The returned number is in the range [0,truth_boxes.size()] - The returned number is in the range [0,truth_boxes.size()]
- ap == the average precision of the given ordering in boxes relative - Adds the score for each box from boxes into all_dets and labels each with
to truth_boxes. a bool indicating if it hit a truth box. Also adds the number of truth
boxes which didn't have any hits into missing_detections.
!*/ !*/
{ {
if (boxes.size() == 0) if (boxes.size() == 0)
{ {
if (truth_boxes.size() == 0) missing_detections += truth_boxes.size();
ap = 1;
else
ap = 0;
return 0; return 0;
} }
...@@ -54,12 +52,13 @@ namespace dlib ...@@ -54,12 +52,13 @@ namespace dlib
{ {
unsigned long best_idx = 0; unsigned long best_idx = 0;
double best_overlap = 0; double best_overlap = 0;
// Find the best box that hits truth_boxes[i]
for (unsigned long j = 0; j < boxes.size(); ++j) for (unsigned long j = 0; j < boxes.size(); ++j)
{ {
if (used[j]) if (used[j])
continue; continue;
const double overlap = truth_boxes[i].get_rect().intersect(boxes[j]).area() / (double)(truth_boxes[i].get_rect()+boxes[j]).area(); const double overlap = truth_boxes[i].get_rect().intersect(boxes[j].second).area() / (double)(truth_boxes[i].get_rect()+boxes[j].second).area();
if (overlap > best_overlap) if (overlap > best_overlap)
{ {
best_overlap = overlap; best_overlap = overlap;
...@@ -67,14 +66,22 @@ namespace dlib ...@@ -67,14 +66,22 @@ namespace dlib
} }
} }
if (best_overlap > overlap_eps && used[best_idx] == false) // if there was any box that hit truth_boxes[i]
if (best_overlap > overlap_eps)
{ {
used[best_idx] = true; used[best_idx] = true;
++count; ++count;
} }
else
{
++missing_detections;
}
} }
ap = average_precision(used, truth_boxes.size()-count); for (unsigned long i = 0; i < boxes.size(); ++i)
{
all_dets.push_back(std::make_pair(boxes[i].first, used[i]));
}
return count; return count;
} }
...@@ -109,19 +116,22 @@ namespace dlib ...@@ -109,19 +116,22 @@ namespace dlib
double correct_hits = 0; double correct_hits = 0;
double total_hits = 0; double total_hits = 0;
double total_true_targets = 0; double total_true_targets = 0;
running_stats<double> map;
std::vector<std::pair<double,bool> > all_dets;
unsigned long missing_detections = 0;
for (unsigned long i = 0; i < images.size(); ++i) for (unsigned long i = 0; i < images.size(); ++i)
{ {
const std::vector<rectangle>& hits = detector(images[i], adjust_threshold); std::vector<std::pair<double,rectangle> > hits;
detector(images[i], hits, adjust_threshold);
double ap;
total_hits += hits.size(); total_hits += hits.size();
correct_hits += impl::number_of_truth_hits(truth_dets[i], hits, overlap_eps, ap); correct_hits += impl::number_of_truth_hits(truth_dets[i], hits, overlap_eps, all_dets, missing_detections);
total_true_targets += truth_dets[i].size(); total_true_targets += truth_dets[i].size();
map.add(ap);
} }
std::sort(all_dets.rbegin(), all_dets.rend());
double precision, recall; double precision, recall;
...@@ -136,7 +146,7 @@ namespace dlib ...@@ -136,7 +146,7 @@ namespace dlib
recall = correct_hits / total_true_targets; recall = correct_hits / total_true_targets;
matrix<double, 1, 3> res; matrix<double, 1, 3> res;
res = precision, recall, map.mean(); res = precision, recall, average_precision(all_dets, missing_detections);
return res; return res;
} }
...@@ -242,7 +252,8 @@ namespace dlib ...@@ -242,7 +252,8 @@ namespace dlib
const long test_size = images.size()/folds; const long test_size = images.size()/folds;
running_stats<double> map; std::vector<std::pair<double,bool> > all_dets;
unsigned long missing_detections = 0;
unsigned long test_idx = 0; unsigned long test_idx = 0;
for (long iter = 0; iter < folds; ++iter) for (long iter = 0; iter < folds; ++iter)
{ {
...@@ -266,17 +277,17 @@ namespace dlib ...@@ -266,17 +277,17 @@ namespace dlib
typename trainer_type::trained_function_type detector = trainer.train(array_subset, training_rects); typename trainer_type::trained_function_type detector = trainer.train(array_subset, training_rects);
for (unsigned long i = 0; i < test_idx_set.size(); ++i) for (unsigned long i = 0; i < test_idx_set.size(); ++i)
{ {
const std::vector<rectangle>& hits = detector(images[test_idx_set[i]], adjust_threshold); std::vector<std::pair<double,rectangle> > hits;
detector(images[test_idx_set[i]], hits, adjust_threshold);
double ap;
total_hits += hits.size(); total_hits += hits.size();
correct_hits += impl::number_of_truth_hits(truth_dets[test_idx_set[i]], hits, overlap_eps, ap); correct_hits += impl::number_of_truth_hits(truth_dets[test_idx_set[i]], hits, overlap_eps, all_dets, missing_detections);
total_true_targets += truth_dets[test_idx_set[i]].size(); total_true_targets += truth_dets[test_idx_set[i]].size();
map.add(ap);
} }
} }
std::sort(all_dets.rbegin(), all_dets.rend());
double precision, recall; double precision, recall;
...@@ -292,7 +303,7 @@ namespace dlib ...@@ -292,7 +303,7 @@ namespace dlib
recall = correct_hits / total_true_targets; recall = correct_hits / total_true_targets;
matrix<double, 1, 3> res; matrix<double, 1, 3> res;
res = precision, recall, map.mean(); res = precision, recall, average_precision(all_dets, missing_detections);
return res; return res;
} }
......
...@@ -34,9 +34,8 @@ namespace dlib ...@@ -34,9 +34,8 @@ namespace dlib
and it must contain objects which can be accepted by detector(). and it must contain objects which can be accepted by detector().
ensures ensures
- Tests the given detector against the supplied object detection problem and - Tests the given detector against the supplied object detection problem and
returns the precision, recall, and mean average precision. Note that the returns the precision, recall, and average precision. Note that the task is
task is to predict, for each images[i], the set of object locations given by to predict, for each images[i], the set of object locations given by truth_dets[i].
truth_dets[i].
- In particular, returns a matrix M such that: - In particular, returns a matrix M such that:
- M(0) == the precision of the detector object. This is a number - M(0) == the precision of the detector object. This is a number
in the range [0,1] which measures the fraction of detector outputs in the range [0,1] which measures the fraction of detector outputs
...@@ -48,12 +47,12 @@ namespace dlib ...@@ -48,12 +47,12 @@ namespace dlib
detector. A value of 1 means the detector found all the targets detector. A value of 1 means the detector found all the targets
in truth_dets while a value of 0 means the detector didn't locate in truth_dets while a value of 0 means the detector didn't locate
any of the targets. any of the targets.
- M(2) == the mean average precision of the detector object. This is a - M(2) == the average precision of the detector object. This is a number
number in the range [0,1] which measures the overall quality of the in the range [0,1] which measures the overall quality of the detector.
detector. We do this by taking all the detections output by the detector We compute this by taking all the detections output by the detector and
and ordering them by their detection score. Then we use the ordering them in descending order of their detection scores. Then we use
average_precision() routine to score the ranked listing. Finally we set the average_precision() routine to score the ranked listing and store the
M(2) to the mean value over all test images. output into M(2).
- The rule for deciding if a detector output, D, matches a truth rectangle, - The rule for deciding if a detector output, D, matches a truth rectangle,
T, is the following: T, is the following:
T and R match if and only if: T.intersect(R).area()/(T+R).area() > overlap_eps T and R match if and only if: T.intersect(R).area()/(T+R).area() > overlap_eps
...@@ -63,7 +62,7 @@ namespace dlib ...@@ -63,7 +62,7 @@ namespace dlib
output detections. It can be useful, for example, to lower the detection output detections. It can be useful, for example, to lower the detection
threshold because it results in more detections being output by the threshold because it results in more detections being output by the
detector, and therefore provides more information in the ranking, detector, and therefore provides more information in the ranking,
possibly raising the mean average precision. possibly raising the average precision.
!*/ !*/
template < template <
...@@ -114,7 +113,7 @@ namespace dlib ...@@ -114,7 +113,7 @@ namespace dlib
- Performs k-fold cross-validation by using the given trainer to solve an - Performs k-fold cross-validation by using the given trainer to solve an
object detection problem for the given number of folds. Each fold is tested object detection problem for the given number of folds. Each fold is tested
using the output of the trainer and a matrix summarizing the results is using the output of the trainer and a matrix summarizing the results is
returned. The matrix contains the precision, recall, and mean average returned. The matrix contains the precision, recall, and average
precision of the trained detectors and is defined identically to the precision of the trained detectors and is defined identically to the
test_object_detection_function() routine defined at the top of this file. test_object_detection_function() routine defined at the top of this file.
!*/ !*/
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment