Commit a73e7659 authored by Davis King's avatar Davis King

Made the object detector validation functions also output the mean average

precision measure.
parent 230dc754
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "svm.h" #include "svm.h"
#include "../geometry.h" #include "../geometry.h"
#include "../image_processing/full_object_detection.h" #include "../image_processing/full_object_detection.h"
#include "../statistics.h"
namespace dlib namespace dlib
{ {
...@@ -20,7 +21,8 @@ namespace dlib ...@@ -20,7 +21,8 @@ namespace dlib
inline unsigned long number_of_truth_hits ( inline unsigned long number_of_truth_hits (
const std::vector<full_object_detection>& truth_boxes, const std::vector<full_object_detection>& truth_boxes,
const std::vector<rectangle>& boxes, const std::vector<rectangle>& boxes,
const double overlap_eps const double overlap_eps,
double& ap
) )
/*! /*!
requires requires
...@@ -32,10 +34,19 @@ namespace dlib ...@@ -32,10 +34,19 @@ namespace dlib
A.intersect(B).area()/(A+B).area() A.intersect(B).area()/(A+B).area()
- No element of boxes is allowed to account for more than one element of truth_boxes. - No element of boxes is allowed to account for more than one element of truth_boxes.
- The returned number is in the range [0,truth_boxes.size()] - The returned number is in the range [0,truth_boxes.size()]
- ap == the average precision of the given ordering in boxes relative
to truth_boxes.
!*/ !*/
{ {
if (boxes.size() == 0) if (boxes.size() == 0)
{
if (truth_boxes.size() == 0)
ap = 1;
else
ap = 0;
return 0; return 0;
}
unsigned long count = 0; unsigned long count = 0;
std::vector<bool> used(boxes.size(),false); std::vector<bool> used(boxes.size(),false);
...@@ -63,8 +74,20 @@ namespace dlib ...@@ -63,8 +74,20 @@ namespace dlib
} }
} }
ap = average_precision(used, truth_boxes.size()-count);
return count; return count;
} }
inline unsigned long number_of_truth_hits (
const std::vector<full_object_detection>& truth_boxes,
const std::vector<rectangle>& boxes,
const double overlap_eps
)
{
double ap;
return number_of_truth_hits(truth_boxes, boxes, overlap_eps, ap);
}
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -73,7 +96,7 @@ namespace dlib ...@@ -73,7 +96,7 @@ namespace dlib
typename object_detector_type, typename object_detector_type,
typename image_array_type typename image_array_type
> >
const matrix<double,1,2> test_object_detection_function ( const matrix<double,1,3> test_object_detection_function (
object_detector_type& detector, object_detector_type& detector,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<full_object_detection> >& truth_dets, const std::vector<std::vector<full_object_detection> >& truth_dets,
...@@ -94,14 +117,30 @@ namespace dlib ...@@ -94,14 +117,30 @@ namespace dlib
double correct_hits = 0; double correct_hits = 0;
double total_hits = 0; double total_hits = 0;
double total_true_targets = 0; double total_true_targets = 0;
running_stats<double> map;
for (unsigned long i = 0; i < images.size(); ++i) for (unsigned long i = 0; i < images.size(); ++i)
{ {
const std::vector<rectangle>& hits = detector(images[i]); std::vector<std::pair<double,rectangle> > all_dets;
detector(images[i], all_dets, -std::numeric_limits<double>::infinity());
std::vector<rectangle> hits;
for (unsigned long k = 0; k < all_dets.size(); ++k)
{
if (all_dets[k].first >= 0)
hits.push_back(all_dets[k].second);
}
total_hits += hits.size(); total_hits += hits.size();
correct_hits += impl::number_of_truth_hits(truth_dets[i], hits, overlap_eps); correct_hits += impl::number_of_truth_hits(truth_dets[i], hits, overlap_eps);
total_true_targets += truth_dets[i].size(); total_true_targets += truth_dets[i].size();
// now get the average precision
hits.clear();
for (unsigned long k = 0; k < all_dets.size(); ++k)
hits.push_back(all_dets[k].second);
double ap;
impl::number_of_truth_hits(truth_dets[i], hits, overlap_eps, ap);
map.add(ap);
} }
...@@ -117,8 +156,8 @@ namespace dlib ...@@ -117,8 +156,8 @@ namespace dlib
else else
recall = correct_hits / total_true_targets; recall = correct_hits / total_true_targets;
matrix<double, 1, 2> res; matrix<double, 1, 3> res;
res = precision, recall; res = precision, recall, map.mean();
return res; return res;
} }
...@@ -126,7 +165,7 @@ namespace dlib ...@@ -126,7 +165,7 @@ namespace dlib
typename object_detector_type, typename object_detector_type,
typename image_array_type typename image_array_type
> >
const matrix<double,1,2> test_object_detection_function ( const matrix<double,1,3> test_object_detection_function (
object_detector_type& detector, object_detector_type& detector,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_dets, const std::vector<std::vector<rectangle> >& truth_dets,
...@@ -197,7 +236,7 @@ namespace dlib ...@@ -197,7 +236,7 @@ namespace dlib
typename trainer_type, typename trainer_type,
typename image_array_type typename image_array_type
> >
const matrix<double,1,2> cross_validate_object_detection_trainer ( const matrix<double,1,3> cross_validate_object_detection_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<full_object_detection> >& truth_dets, const std::vector<std::vector<full_object_detection> >& truth_dets,
...@@ -222,6 +261,7 @@ namespace dlib ...@@ -222,6 +261,7 @@ namespace dlib
const long test_size = images.size()/folds; const long test_size = images.size()/folds;
running_stats<double> map;
unsigned long test_idx = 0; unsigned long test_idx = 0;
for (long iter = 0; iter < folds; ++iter) for (long iter = 0; iter < folds; ++iter)
{ {
...@@ -245,11 +285,26 @@ namespace dlib ...@@ -245,11 +285,26 @@ namespace dlib
typename trainer_type::trained_function_type detector = trainer.train(array_subset, training_rects); typename trainer_type::trained_function_type detector = trainer.train(array_subset, training_rects);
for (unsigned long i = 0; i < test_idx_set.size(); ++i) for (unsigned long i = 0; i < test_idx_set.size(); ++i)
{ {
const std::vector<rectangle>& hits = detector(images[test_idx_set[i]]); std::vector<std::pair<double,rectangle> > all_dets;
detector(images[test_idx_set[i]], all_dets, -std::numeric_limits<double>::infinity());
std::vector<rectangle> hits;
for (unsigned long k = 0; k < all_dets.size(); ++k)
{
if (all_dets[k].first >= 0)
hits.push_back(all_dets[k].second);
}
total_hits += hits.size(); total_hits += hits.size();
correct_hits += impl::number_of_truth_hits(truth_dets[test_idx_set[i]], hits, overlap_eps); correct_hits += impl::number_of_truth_hits(truth_dets[test_idx_set[i]], hits, overlap_eps);
total_true_targets += truth_dets[test_idx_set[i]].size(); total_true_targets += truth_dets[test_idx_set[i]].size();
// now get the average precision
hits.clear();
for (unsigned long k = 0; k < all_dets.size(); ++k)
hits.push_back(all_dets[k].second);
double ap;
impl::number_of_truth_hits(truth_dets[test_idx_set[i]], hits, overlap_eps, ap);
map.add(ap);
} }
} }
...@@ -268,8 +323,8 @@ namespace dlib ...@@ -268,8 +323,8 @@ namespace dlib
else else
recall = correct_hits / total_true_targets; recall = correct_hits / total_true_targets;
matrix<double, 1, 2> res; matrix<double, 1, 3> res;
res = precision, recall; res = precision, recall, map.mean();
return res; return res;
} }
...@@ -277,7 +332,7 @@ namespace dlib ...@@ -277,7 +332,7 @@ namespace dlib
typename trainer_type, typename trainer_type,
typename image_array_type typename image_array_type
> >
const matrix<double,1,2> cross_validate_object_detection_trainer ( const matrix<double,1,3> cross_validate_object_detection_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_dets, const std::vector<std::vector<rectangle> >& truth_dets,
......
...@@ -17,7 +17,7 @@ namespace dlib ...@@ -17,7 +17,7 @@ namespace dlib
typename object_detector_type, typename object_detector_type,
typename image_array_type typename image_array_type
> >
const matrix<double,1,2> test_object_detection_function ( const matrix<double,1,3> test_object_detection_function (
object_detector_type& detector, object_detector_type& detector,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<full_object_detection> >& truth_dets, const std::vector<std::vector<full_object_detection> >& truth_dets,
...@@ -32,9 +32,10 @@ namespace dlib ...@@ -32,9 +32,10 @@ namespace dlib
- image_array_type must be an implementation of dlib/array/array_kernel_abstract.h - image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
and it must contain objects which can be accepted by detector(). and it must contain objects which can be accepted by detector().
ensures ensures
- Tests the given detector against the supplied object detection problem - Tests the given detector against the supplied object detection problem and
and returns the precision and recall. Note that the task is to predict, returns the precision, recall, and mean average precision. Note that the
for each images[i], the set of object locations given by truth_dets[i]. task is to predict, for each images[i], the set of object locations given by
truth_dets[i].
- In particular, returns a matrix M such that: - In particular, returns a matrix M such that:
- M(0) == the precision of the detector object. This is a number - M(0) == the precision of the detector object. This is a number
in the range [0,1] which measures the fraction of detector outputs in the range [0,1] which measures the fraction of detector outputs
...@@ -42,10 +43,18 @@ namespace dlib ...@@ -42,10 +43,18 @@ namespace dlib
never produces any false alarms while a value of 0 means it only never produces any false alarms while a value of 0 means it only
produces false alarms. produces false alarms.
- M(1) == the recall of the detector object. This is a number in the - M(1) == the recall of the detector object. This is a number in the
range [0,1] which measure the fraction of targets found by the range [0,1] which measures the fraction of targets found by the
detector. A value of 1 means the detector found all the targets detector. A value of 1 means the detector found all the targets
in truth_dets while a value of 0 means the detector didn't locate in truth_dets while a value of 0 means the detector didn't locate
any of the targets. any of the targets.
- M(2) == the mean average precision of the detector object. This is a
number in the range [0,1] which measures the overall quality of the
detector when the detector is asked to output a ranked listing of all
possible detections. In particular, this is accomplished by setting the
detection threshold such that all possible detections are output. Then
the detections are ordered by their detection score and we use the
average_precision() routine to score each ranked listing, finally setting
M(2) to the mean value over all test images.
- The rule for deciding if a detector output, D, matches a truth rectangle, - The rule for deciding if a detector output, D, matches a truth rectangle,
T, is the following: T, is the following:
T and R match if and only if: T.intersect(R).area()/(T+R).area() > overlap_eps T and R match if and only if: T.intersect(R).area()/(T+R).area() > overlap_eps
...@@ -55,7 +64,7 @@ namespace dlib ...@@ -55,7 +64,7 @@ namespace dlib
typename object_detector_type, typename object_detector_type,
typename image_array_type typename image_array_type
> >
const matrix<double,1,2> test_object_detection_function ( const matrix<double,1,3> test_object_detection_function (
object_detector_type& detector, object_detector_type& detector,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_dets, const std::vector<std::vector<rectangle> >& truth_dets,
...@@ -77,7 +86,7 @@ namespace dlib ...@@ -77,7 +86,7 @@ namespace dlib
typename trainer_type, typename trainer_type,
typename image_array_type typename image_array_type
> >
const matrix<double,1,2> cross_validate_object_detection_trainer ( const matrix<double,1,3> cross_validate_object_detection_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<full_object_detection> >& truth_dets, const std::vector<std::vector<full_object_detection> >& truth_dets,
...@@ -97,16 +106,16 @@ namespace dlib ...@@ -97,16 +106,16 @@ namespace dlib
- Performs k-fold cross-validation by using the given trainer to solve an - Performs k-fold cross-validation by using the given trainer to solve an
object detection problem for the given number of folds. Each fold is tested object detection problem for the given number of folds. Each fold is tested
using the output of the trainer and a matrix summarizing the results is using the output of the trainer and a matrix summarizing the results is
returned. The matrix contains the precision and recall of the trained returned. The matrix contains the precision, recall, and mean average
detectors and is defined identically to the test_object_detection_function() precision of the trained detectors and is defined identically to the
routine defined at the top of this file. test_object_detection_function() routine defined at the top of this file.
!*/ !*/
template < template <
typename trainer_type, typename trainer_type,
typename image_array_type typename image_array_type
> >
const matrix<double,1,2> cross_validate_object_detection_trainer ( const matrix<double,1,3> cross_validate_object_detection_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_dets, const std::vector<std::vector<rectangle> >& truth_dets,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment