Commit 9d3e40a6 authored by Davis King's avatar Davis King

Added support for ignore rectangles into the object detection testing functions. I also

changed the interfaces to these functions slightly.  Instead of taking a double that
determines how we decide if boxes match, they now take a test_box_overlap object.
parent f53da5ed
......@@ -9,6 +9,7 @@
#include "svm.h"
#include "../geometry.h"
#include "../image_processing/full_object_detection.h"
#include "../image_processing/box_overlap_testing.h"
#include "../statistics.h"
namespace dlib
......@@ -21,18 +22,15 @@ namespace dlib
inline unsigned long number_of_truth_hits (
const std::vector<full_object_detection>& truth_boxes,
const std::vector<std::pair<double,rectangle> >& boxes,
const double overlap_eps,
const test_box_overlap& overlap_tester,
std::vector<std::pair<double,bool> >& all_dets,
unsigned long& missing_detections
)
/*!
requires
- 0 < overlap_eps <= 1
ensures
- returns the number of elements in truth_boxes which are overlapped by an
element of boxes. In this context, two boxes, A and B, overlap if and only if
the following quantity is greater than overlap_eps:
A.intersect(B).area()/(A+B).area()
overlap_tester(A,B) == true.
- No element of boxes is allowed to account for more than one element of truth_boxes.
- The returned number is in the range [0,truth_boxes.size()]
- Adds the score for each box from boxes into all_dets and labels each with
......@@ -57,10 +55,7 @@ namespace dlib
if (used[j])
continue;
const double overlap = truth_boxes[i].get_rect().intersect(boxes[j].second).area() /
(double)(truth_boxes[i].get_rect()+boxes[j].second).area();
if (overlap >= overlap_eps)
if (overlap_tester(truth_boxes[i].get_rect(), boxes[j].second))
{
used[j] = true;
++count;
......@@ -81,6 +76,25 @@ namespace dlib
return count;
}
// ------------------------------------------------------------------------------------
inline unsigned long count_dets_not_hitting_ignore_boxes (
const test_box_overlap& overlap_tester,
const std::vector<rectangle>& ignore,
const std::vector<std::pair<double,rectangle> >& dets
)
{
unsigned long count = 0;
for (unsigned long i = 0; i < dets.size(); ++i)
{
if (!overlaps_any_box(overlap_tester, ignore, dets[i].second))
{
++count;
}
}
return count;
}
}
// ----------------------------------------------------------------------------------------
......@@ -93,17 +107,19 @@ namespace dlib
object_detector_type& detector,
const image_array_type& images,
const std::vector<std::vector<full_object_detection> >& truth_dets,
const double overlap_eps = 0.5,
const std::vector<std::vector<rectangle> >& ignore,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
)
{
// make sure requires clause is not broken
DLIB_CASSERT( is_learning_problem(images,truth_dets) == true &&
0 < overlap_eps && overlap_eps <= 1,
ignore.size() == images.size(),
"\t matrix test_object_detection_function()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets)
<< "\n\t overlap_eps: "<< overlap_eps
<< "\n\t ignore.size(): " << ignore.size()
<< "\n\t images.size(): " << images.size()
);
......@@ -121,8 +137,8 @@ namespace dlib
std::vector<std::pair<double,rectangle> > hits;
detector(images[i], hits, adjust_threshold);
total_hits += hits.size();
correct_hits += impl::number_of_truth_hits(truth_dets[i], hits, overlap_eps, all_dets, missing_detections);
total_hits += impl::count_dets_not_hitting_ignore_boxes(overlap_tester, ignore[i], hits);
correct_hits += impl::number_of_truth_hits(truth_dets[i], hits, overlap_tester, all_dets, missing_detections);
total_true_targets += truth_dets[i].size();
}
......@@ -130,6 +146,11 @@ namespace dlib
double precision, recall;
// If the user put an ignore box on the same spot as a truth box then we could end
// up with a total_hits value less than correct_hits. So we do this to make sure
// the precision value never goes above 1.
total_hits = std::max(total_hits,correct_hits);
if (total_hits == 0)
precision = 1;
else
......@@ -153,7 +174,8 @@ namespace dlib
object_detector_type& detector,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_dets,
const double overlap_eps = 0.5,
const std::vector<std::vector<rectangle> >& ignore,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
)
{
......@@ -167,9 +189,44 @@ namespace dlib
}
}
return test_object_detection_function(detector, images, rects, overlap_eps, adjust_threshold);
return test_object_detection_function(detector, images, rects, ignore, overlap_tester, adjust_threshold);
}
template <
typename object_detector_type,
typename image_array_type
>
const matrix<double,1,3> test_object_detection_function (
object_detector_type& detector,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_dets,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
)
{
std::vector<std::vector<rectangle> > ignore(images.size());
return test_object_detection_function(detector,images,truth_dets,ignore, overlap_tester, adjust_threshold);
}
template <
typename object_detector_type,
typename image_array_type
>
const matrix<double,1,3> test_object_detection_function (
object_detector_type& detector,
const image_array_type& images,
const std::vector<std::vector<full_object_detection> >& truth_dets,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
)
{
std::vector<std::vector<rectangle> > ignore(images.size());
return test_object_detection_function(detector,images,truth_dets,ignore, overlap_tester, adjust_threshold);
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
namespace impl
......@@ -225,20 +282,22 @@ namespace dlib
const trainer_type& trainer,
const image_array_type& images,
const std::vector<std::vector<full_object_detection> >& truth_dets,
const std::vector<std::vector<rectangle> >& ignore,
const long folds,
const double overlap_eps = 0.5,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
)
{
// make sure requires clause is not broken
DLIB_CASSERT( is_learning_problem(images,truth_dets) == true &&
0 < overlap_eps && overlap_eps <= 1 &&
ignore.size() == images.size() &&
1 < folds && folds <= static_cast<long>(images.size()),
"\t matrix cross_validate_object_detection_trainer()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets)
<< "\n\t overlap_eps: "<< overlap_eps
<< "\n\t folds: "<< folds
<< "\n\t ignore.size(): " << ignore.size()
<< "\n\t images.size(): " << images.size()
);
double correct_hits = 0;
......@@ -260,23 +319,25 @@ namespace dlib
unsigned long train_idx = test_idx%images.size();
std::vector<std::vector<full_object_detection> > training_rects;
std::vector<std::vector<rectangle> > training_ignores;
for (unsigned long i = 0; i < images.size()-test_size; ++i)
{
training_rects.push_back(truth_dets[train_idx]);
training_ignores.push_back(ignore[train_idx]);
train_idx_set.push_back(train_idx);
train_idx = (train_idx+1)%images.size();
}
impl::array_subset_helper<image_array_type> array_subset(images, train_idx_set);
typename trainer_type::trained_function_type detector = trainer.train(array_subset, training_rects);
typename trainer_type::trained_function_type detector = trainer.train(array_subset, training_rects, training_ignores, overlap_tester);
for (unsigned long i = 0; i < test_idx_set.size(); ++i)
{
std::vector<std::pair<double,rectangle> > hits;
detector(images[test_idx_set[i]], hits, adjust_threshold);
total_hits += hits.size();
correct_hits += impl::number_of_truth_hits(truth_dets[test_idx_set[i]], hits, overlap_eps, all_dets, missing_detections);
total_hits += impl::count_dets_not_hitting_ignore_boxes(overlap_tester, ignore[i], hits);
correct_hits += impl::number_of_truth_hits(truth_dets[test_idx_set[i]], hits, overlap_tester, all_dets, missing_detections);
total_true_targets += truth_dets[test_idx_set[i]].size();
}
......@@ -287,6 +348,11 @@ namespace dlib
double precision, recall;
// If the user put an ignore box on the same spot as a truth box then we could end
// up with a total_hits value less than correct_hits. So we do this to make sure
// the precision value never goes above 1.
total_hits = std::max(total_hits,correct_hits);
if (total_hits == 0)
precision = 1;
else
......@@ -310,8 +376,9 @@ namespace dlib
const trainer_type& trainer,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_dets,
const std::vector<std::vector<rectangle> >& ignore,
const long folds,
const double overlap_eps = 0.5,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
)
{
......@@ -325,7 +392,41 @@ namespace dlib
}
}
return cross_validate_object_detection_trainer(trainer, images, dets, folds, overlap_eps, adjust_threshold);
return cross_validate_object_detection_trainer(trainer, images, dets, ignore, folds, overlap_tester, adjust_threshold);
}
template <
typename trainer_type,
typename image_array_type
>
const matrix<double,1,3> cross_validate_object_detection_trainer (
const trainer_type& trainer,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_dets,
const long folds,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
)
{
const std::vector<std::vector<rectangle> > ignore(images.size());
return cross_validate_object_detection_trainer(trainer,images,truth_dets,ignore,folds,overlap_tester,adjust_threshold);
}
template <
typename trainer_type,
typename image_array_type
>
const matrix<double,1,3> cross_validate_object_detection_trainer (
const trainer_type& trainer,
const image_array_type& images,
const std::vector<std::vector<full_object_detection> >& truth_dets,
const long folds,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
)
{
const std::vector<std::vector<rectangle> > ignore(images.size());
return cross_validate_object_detection_trainer(trainer,images,truth_dets,ignore,folds,overlap_tester,adjust_threshold);
}
// ----------------------------------------------------------------------------------------
......
......@@ -21,13 +21,14 @@ namespace dlib
object_detector_type& detector,
const image_array_type& images,
const std::vector<std::vector<full_object_detection> >& truth_dets,
const double overlap_eps = 0.5,
const std::vector<std::vector<rectangle> >& ignore,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
);
/*!
requires
- is_learning_problem(images,truth_dets)
- 0 < overlap_eps <= 1
- images.size() == ignore.size()
- object_detector_type == some kind of object detector function object
(e.g. object_detector)
- image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
......@@ -35,7 +36,12 @@ namespace dlib
ensures
- Tests the given detector against the supplied object detection problem and
returns the precision, recall, and average precision. Note that the task is
to predict, for each images[i], the set of object locations given by truth_dets[i].
to predict, for each images[i], the set of object locations given by
truth_dets[i]. Additionally, any detections on image[i] that match a box in
ignore[i] are ignored. That is, detections matching a box in ignore[i] do
not count as a false alarm and similarly if any element of ignore[i] goes
undetected it does not count as a missed detection. So we say that ignore[i]
contains a set of boxes that we "don't care" if they are detected or not.
- In particular, returns a matrix M such that:
- M(0) == the precision of the detector object. This is a number
in the range [0,1] which measures the fraction of detector outputs
......@@ -53,9 +59,8 @@ namespace dlib
ordering them in descending order of their detection scores. Then we use
the average_precision() routine to score the ranked listing and store the
output into M(2).
- The rule for deciding if a detector output, D, matches a truth rectangle,
T, is the following:
T and R match if and only if: T.intersect(R).area()/(T+R).area() > overlap_eps
- This function considers a detector output D to match a rectangle T if and
only if overlap_tester(T,D) returns true.
- Note that you can use the adjust_threshold argument to raise or lower the
detection threshold. This value is passed into the identically named
argument to the detector object and therefore influences the number of
......@@ -73,12 +78,13 @@ namespace dlib
object_detector_type& detector,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_dets,
const double overlap_eps = 0.5,
const std::vector<std::vector<rectangle> >& ignore,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
);
/*!
requires
- all the requirements of the above test_object_detection_function() routine.
- All the requirements of the above test_object_detection_function() routine.
ensures
- converts all the rectangles in truth_dets into full_object_detection objects
via full_object_detection's rectangle constructor. Then invokes
......@@ -86,6 +92,46 @@ namespace dlib
the results.
!*/
template <
typename object_detector_type,
typename image_array_type
>
const matrix<double,1,3> test_object_detection_function (
object_detector_type& detector,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_dets,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
);
/*!
requires
- All the requirements of the above test_object_detection_function() routine.
ensures
- This function simply invokes test_object_detection_function() with all the
given arguments and an empty set of ignore rectangles and returns the results.
!*/
template <
typename object_detector_type,
typename image_array_type
>
const matrix<double,1,3> test_object_detection_function (
object_detector_type& detector,
const image_array_type& images,
const std::vector<std::vector<full_object_detection> >& truth_dets,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
);
/*!
requires
- All the requirements of the above test_object_detection_function() routine.
ensures
- This function simply invokes test_object_detection_function() with all the
given arguments and an empty set of ignore rectangles and returns the results.
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
......@@ -96,14 +142,15 @@ namespace dlib
const trainer_type& trainer,
const image_array_type& images,
const std::vector<std::vector<full_object_detection> >& truth_dets,
const std::vector<std::vector<rectangle> >& ignore,
const long folds,
const double overlap_eps = 0.5,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
);
/*!
requires
- is_learning_problem(images,truth_dets)
- 0 < overlap_eps <= 1
- images.size() == ignore.size()
- 1 < folds <= images.size()
- trainer_type == some kind of object detection trainer (e.g structural_object_detection_trainer)
- image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
......@@ -126,8 +173,9 @@ namespace dlib
const trainer_type& trainer,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_dets,
const std::vector<std::vector<rectangle> >& ignore,
const long folds,
const double overlap_eps = 0.5,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
);
/*!
......@@ -139,6 +187,47 @@ namespace dlib
cross_validate_object_detection_trainer() on the full_object_detections and
returns the results.
!*/
template <
typename trainer_type,
typename image_array_type
>
const matrix<double,1,3> cross_validate_object_detection_trainer (
const trainer_type& trainer,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_dets,
const long folds,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
);
/*!
requires
- All the requirements of the above cross_validate_object_detection_trainer() routine.
ensures
- This function simply invokes cross_validate_object_detection_trainer() with all
the given arguments and an empty set of ignore rectangles and returns the results.
!*/
template <
typename trainer_type,
typename image_array_type
>
const matrix<double,1,3> cross_validate_object_detection_trainer (
const trainer_type& trainer,
const image_array_type& images,
const std::vector<std::vector<full_object_detection> >& truth_dets,
const long folds,
const test_box_overlap& overlap_tester = test_box_overlap(),
const double adjust_threshold = 0
);
/*!
requires
- All the requirements of the above cross_validate_object_detection_trainer() routine.
ensures
- This function simply invokes cross_validate_object_detection_trainer() with all
the given arguments and an empty set of ignore rectangles and returns the results.
!*/
// ----------------------------------------------------------------------------------------
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment