Commit c9ed8aa6 authored by Davis King

Changed object detector testing functions to output average precision instead of mean average precision.
parent 3bfc3612
......@@ -20,9 +20,10 @@ namespace dlib
inline unsigned long number_of_truth_hits (
const std::vector<full_object_detection>& truth_boxes,
const std::vector<rectangle>& boxes,
const std::vector<std::pair<double,rectangle> >& boxes,
const double overlap_eps,
double& ap
std::vector<std::pair<double,bool> >& all_dets,
unsigned long& missing_detections
......@@ -34,17 +35,14 @@ namespace dlib
- No element of boxes is allowed to account for more than one element of truth_boxes.
- The returned number is in the range [0,truth_boxes.size()]
- ap == the average precision of the given ordering in boxes relative
to truth_boxes.
- Adds the score for each box from boxes into all_dets and labels each with
a bool indicating if it hit a truth box. Also adds the number of truth
boxes which didn't have any hits into missing_detections.
if (boxes.size() == 0)
if (truth_boxes.size() == 0)
ap = 1;
ap = 0;
missing_detections += truth_boxes.size();
return 0;
......@@ -54,12 +52,13 @@ namespace dlib
unsigned long best_idx = 0;
double best_overlap = 0;
// Find the best box that hits truth_boxes[i]
for (unsigned long j = 0; j < boxes.size(); ++j)
if (used[j])
const double overlap = truth_boxes[i].get_rect().intersect(boxes[j]).area() / (double)(truth_boxes[i].get_rect()+boxes[j]).area();
const double overlap = truth_boxes[i].get_rect().intersect(boxes[j].second).area() / (double)(truth_boxes[i].get_rect()+boxes[j].second).area();
if (overlap > best_overlap)
best_overlap = overlap;
......@@ -67,14 +66,22 @@ namespace dlib
if (best_overlap > overlap_eps && used[best_idx] == false)
// if there was any box that hit truth_boxes[i]
if (best_overlap > overlap_eps)
used[best_idx] = true;
ap = average_precision(used, truth_boxes.size()-count);
for (unsigned long i = 0; i < boxes.size(); ++i)
all_dets.push_back(std::make_pair(boxes[i].first, used[i]));
return count;
......@@ -109,19 +116,22 @@ namespace dlib
double correct_hits = 0;
double total_hits = 0;
double total_true_targets = 0;
running_stats<double> map;
std::vector<std::pair<double,bool> > all_dets;
unsigned long missing_detections = 0;
for (unsigned long i = 0; i < images.size(); ++i)
const std::vector<rectangle>& hits = detector(images[i], adjust_threshold);
std::vector<std::pair<double,rectangle> > hits;
detector(images[i], hits, adjust_threshold);
double ap;
total_hits += hits.size();
correct_hits += impl::number_of_truth_hits(truth_dets[i], hits, overlap_eps, ap);
correct_hits += impl::number_of_truth_hits(truth_dets[i], hits, overlap_eps, all_dets, missing_detections);
total_true_targets += truth_dets[i].size();
std::sort(all_dets.rbegin(), all_dets.rend());
double precision, recall;
......@@ -136,7 +146,7 @@ namespace dlib
recall = correct_hits / total_true_targets;
matrix<double, 1, 3> res;
res = precision, recall, map.mean();
res = precision, recall, average_precision(all_dets, missing_detections);
return res;
......@@ -242,7 +252,8 @@ namespace dlib
const long test_size = images.size()/folds;
running_stats<double> map;
std::vector<std::pair<double,bool> > all_dets;
unsigned long missing_detections = 0;
unsigned long test_idx = 0;
for (long iter = 0; iter < folds; ++iter)
......@@ -266,17 +277,17 @@ namespace dlib
typename trainer_type::trained_function_type detector = trainer.train(array_subset, training_rects);
for (unsigned long i = 0; i < test_idx_set.size(); ++i)
const std::vector<rectangle>& hits = detector(images[test_idx_set[i]], adjust_threshold);
std::vector<std::pair<double,rectangle> > hits;
detector(images[test_idx_set[i]], hits, adjust_threshold);
double ap;
total_hits += hits.size();
correct_hits += impl::number_of_truth_hits(truth_dets[test_idx_set[i]], hits, overlap_eps, ap);
correct_hits += impl::number_of_truth_hits(truth_dets[test_idx_set[i]], hits, overlap_eps, all_dets, missing_detections);
total_true_targets += truth_dets[test_idx_set[i]].size();
std::sort(all_dets.rbegin(), all_dets.rend());
double precision, recall;
......@@ -292,7 +303,7 @@ namespace dlib
recall = correct_hits / total_true_targets;
matrix<double, 1, 3> res;
res = precision, recall, map.mean();
res = precision, recall, average_precision(all_dets, missing_detections);
return res;
......@@ -34,9 +34,8 @@ namespace dlib
and it must contain objects which can be accepted by detector().
- Tests the given detector against the supplied object detection problem and
returns the precision, recall, and mean average precision. Note that the
task is to predict, for each images[i], the set of object locations given by
returns the precision, recall, and average precision. Note that the task is
to predict, for each images[i], the set of object locations given by truth_dets[i].
- In particular, returns a matrix M such that:
- M(0) == the precision of the detector object. This is a number
in the range [0,1] which measures the fraction of detector outputs
......@@ -48,12 +47,12 @@ namespace dlib
detector. A value of 1 means the detector found all the targets
in truth_dets while a value of 0 means the detector didn't locate
any of the targets.
- M(2) == the mean average precision of the detector object. This is a
number in the range [0,1] which measures the overall quality of the
detector. We do this by taking all the detections output by the detector
and ordering them by their detection score. Then we use the
average_precision() routine to score the ranked listing. Finally we set
M(2) to the mean value over all test images.
- M(2) == the average precision of the detector object. This is a number
in the range [0,1] which measures the overall quality of the detector.
We compute this by taking all the detections output by the detector and
ordering them in descending order of their detection scores. Then we use
the average_precision() routine to score the ranked listing and store the
output into M(2).
- The rule for deciding if a detector output, D, matches a truth rectangle,
T, is the following:
T and D match if and only if: T.intersect(D).area()/(T+D).area() > overlap_eps
......@@ -63,7 +62,7 @@ namespace dlib
output detections. It can be useful, for example, to lower the detection
threshold because it results in more detections being output by the
detector, and therefore provides more information in the ranking,
possibly raising the mean average precision.
possibly raising the average precision.
template <
......@@ -114,7 +113,7 @@ namespace dlib
- Performs k-fold cross-validation by using the given trainer to solve an
object detection problem for the given number of folds. Each fold is tested
using the output of the trainer and a matrix summarizing the results is
returned. The matrix contains the precision, recall, and mean average
returned. The matrix contains the precision, recall, and average
precision of the trained detectors and is defined identically to the
test_object_detection_function() routine defined at the top of this file.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment