Commit a32002ea authored by Davis King's avatar Davis King

Added overloads of the object detection test tools that work with

both rectangle and full_object_detection objects.
parent 786b1dbd
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "../matrix.h" #include "../matrix.h"
#include "svm.h" #include "svm.h"
#include "../geometry.h" #include "../geometry.h"
#include "../image_processing/full_object_detection.h"
namespace dlib namespace dlib
{ {
...@@ -17,7 +18,7 @@ namespace dlib ...@@ -17,7 +18,7 @@ namespace dlib
namespace impl namespace impl
{ {
inline unsigned long number_of_truth_hits ( inline unsigned long number_of_truth_hits (
const std::vector<rectangle>& truth_boxes, const std::vector<full_object_detection>& truth_boxes,
const std::vector<rectangle>& boxes, const std::vector<rectangle>& boxes,
const double overlap_eps const double overlap_eps
) )
...@@ -47,7 +48,7 @@ namespace dlib ...@@ -47,7 +48,7 @@ namespace dlib
if (used[j]) if (used[j])
continue; continue;
const double overlap = truth_boxes[i].intersect(boxes[j]).area() / (double)(truth_boxes[i]+boxes[j]).area(); const double overlap = truth_boxes[i].rect.intersect(boxes[j]).area() / (double)(truth_boxes[i].rect+boxes[j]).area();
if (overlap > best_overlap) if (overlap > best_overlap)
{ {
best_overlap = overlap; best_overlap = overlap;
...@@ -75,7 +76,7 @@ namespace dlib ...@@ -75,7 +76,7 @@ namespace dlib
const matrix<double,1,2> test_object_detection_function ( const matrix<double,1,2> test_object_detection_function (
object_detector_type& detector, object_detector_type& detector,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects, const std::vector<std::vector<full_object_detection> >& truth_rects,
const double overlap_eps = 0.5 const double overlap_eps = 0.5
) )
{ {
...@@ -121,6 +122,30 @@ namespace dlib ...@@ -121,6 +122,30 @@ namespace dlib
return res; return res;
} }
template <
typename object_detector_type,
typename image_array_type
>
const matrix<double,1,2> test_object_detection_function (
object_detector_type& detector,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects,
const double overlap_eps = 0.5
)
{
// convert into a list of regular rectangles.
std::vector<std::vector<full_object_detection> > rects(truth_rects.size());
for (unsigned long i = 0; i < truth_rects.size(); ++i)
{
for (unsigned long j = 0; j < truth_rects[i].size(); ++j)
{
rects[i].push_back(full_object_detection(truth_rects[i][j]));
}
}
return test_object_detection_function(detector, images, rects, overlap_eps);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
namespace impl namespace impl
...@@ -163,18 +188,18 @@ namespace dlib ...@@ -163,18 +188,18 @@ namespace dlib
const matrix<double,1,2> cross_validate_object_detection_trainer ( const matrix<double,1,2> cross_validate_object_detection_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects, const std::vector<std::vector<full_object_detection> >& truth_object_detections,
const long folds, const long folds,
const double overlap_eps = 0.5 const double overlap_eps = 0.5
) )
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT( is_learning_problem(images,truth_rects) == true && DLIB_ASSERT( is_learning_problem(images,truth_object_detections) == true &&
0 < overlap_eps && overlap_eps <= 1 && 0 < overlap_eps && overlap_eps <= 1 &&
1 < folds && folds <= static_cast<long>(images.size()), 1 < folds && folds <= static_cast<long>(images.size()),
"\t matrix cross_validate_object_detection_trainer()" "\t matrix cross_validate_object_detection_trainer()"
<< "\n\t invalid inputs were given to this function" << "\n\t invalid inputs were given to this function"
<< "\n\t is_learning_problem(images,truth_rects): " << is_learning_problem(images,truth_rects) << "\n\t is_learning_problem(images,truth_object_detections): " << is_learning_problem(images,truth_object_detections)
<< "\n\t overlap_eps: "<< overlap_eps << "\n\t overlap_eps: "<< overlap_eps
<< "\n\t folds: "<< folds << "\n\t folds: "<< folds
); );
...@@ -195,10 +220,10 @@ namespace dlib ...@@ -195,10 +220,10 @@ namespace dlib
test_idx_set.push_back(test_idx++); test_idx_set.push_back(test_idx++);
unsigned long train_idx = test_idx%images.size(); unsigned long train_idx = test_idx%images.size();
std::vector<std::vector<rectangle> > training_rects; std::vector<std::vector<full_object_detection> > training_rects;
for (unsigned long i = 0; i < images.size()-test_size; ++i) for (unsigned long i = 0; i < images.size()-test_size; ++i)
{ {
training_rects.push_back(truth_rects[train_idx]); training_rects.push_back(truth_object_detections[train_idx]);
train_idx_set.push_back(train_idx); train_idx_set.push_back(train_idx);
train_idx = (train_idx+1)%images.size(); train_idx = (train_idx+1)%images.size();
} }
...@@ -211,8 +236,8 @@ namespace dlib ...@@ -211,8 +236,8 @@ namespace dlib
const std::vector<rectangle>& hits = detector(images[test_idx_set[i]]); const std::vector<rectangle>& hits = detector(images[test_idx_set[i]]);
total_hits += hits.size(); total_hits += hits.size();
correct_hits += impl::number_of_truth_hits(truth_rects[test_idx_set[i]], hits, overlap_eps); correct_hits += impl::number_of_truth_hits(truth_object_detections[test_idx_set[i]], hits, overlap_eps);
total_true_targets += truth_rects[test_idx_set[i]].size(); total_true_targets += truth_object_detections[test_idx_set[i]].size();
} }
} }
...@@ -236,6 +261,31 @@ namespace dlib ...@@ -236,6 +261,31 @@ namespace dlib
return res; return res;
} }
template <
typename trainer_type,
typename image_array_type
>
const matrix<double,1,2> cross_validate_object_detection_trainer (
const trainer_type& trainer,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_object_detections,
const long folds,
const double overlap_eps = 0.5
)
{
// convert into a list of regular rectangles.
std::vector<std::vector<full_object_detection> > dets(truth_object_detections.size());
for (unsigned long i = 0; i < truth_object_detections.size(); ++i)
{
for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j)
{
dets[i].push_back(full_object_detection(truth_object_detections[i][j]));
}
}
return cross_validate_object_detection_trainer(trainer, images, dets, folds, overlap_eps);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <vector> #include <vector>
#include "../matrix.h" #include "../matrix.h"
#include "../geometry.h" #include "../geometry.h"
#include "../image_processing/full_object_detection_abstract.h"
namespace dlib namespace dlib
{ {
...@@ -19,7 +20,7 @@ namespace dlib ...@@ -19,7 +20,7 @@ namespace dlib
const matrix<double,1,2> test_object_detection_function ( const matrix<double,1,2> test_object_detection_function (
object_detector_type& detector, object_detector_type& detector,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects, const std::vector<std::vector<full_object_detection> >& truth_rects,
const double overlap_eps = 0.5 const double overlap_eps = 0.5
); );
/*! /*!
...@@ -50,6 +51,26 @@ namespace dlib ...@@ -50,6 +51,26 @@ namespace dlib
T and R match if and only if: T.intersect(R).area()/(T+R).area() > overlap_eps T and R match if and only if: T.intersect(R).area()/(T+R).area() > overlap_eps
!*/ !*/
template <
typename object_detector_type,
typename image_array_type
>
const matrix<double,1,2> test_object_detection_function (
object_detector_type& detector,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects,
const double overlap_eps = 0.5
);
/*!
requires
- all the requirements of the above test_object_detection_function() routine.
ensures
- converts all the rectangles in truth_rects into full_object_detection objects
via full_object_detection's rectangle constructor. Then invokes
test_object_detection_function() on the full_object_detections and returns
the results.
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -59,7 +80,7 @@ namespace dlib ...@@ -59,7 +80,7 @@ namespace dlib
const matrix<double,1,2> cross_validate_object_detection_trainer ( const matrix<double,1,2> cross_validate_object_detection_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects, const std::vector<std::vector<full_object_detection> >& truth_rects,
const long folds, const long folds,
const double overlap_eps = 0.5 const double overlap_eps = 0.5
); );
...@@ -71,6 +92,7 @@ namespace dlib ...@@ -71,6 +92,7 @@ namespace dlib
- trainer_type == some kind of object detection trainer (e.g structural_object_detection_trainer) - trainer_type == some kind of object detection trainer (e.g structural_object_detection_trainer)
- image_array_type must be an implementation of dlib/array/array_kernel_abstract.h - image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
and it must contain objects which can be accepted by detector(). and it must contain objects which can be accepted by detector().
- it is legal to call trainer.train(images, truth_rects)
ensures ensures
- Performs k-fold cross-validation by using the given trainer to solve an - Performs k-fold cross-validation by using the given trainer to solve an
object detection problem for the given number of folds. Each fold is tested object detection problem for the given number of folds. Each fold is tested
...@@ -80,6 +102,26 @@ namespace dlib ...@@ -80,6 +102,26 @@ namespace dlib
routine defined at the top of this file. routine defined at the top of this file.
!*/ !*/
template <
typename trainer_type,
typename image_array_type
>
const matrix<double,1,2> cross_validate_object_detection_trainer (
const trainer_type& trainer,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects,
const long folds,
const double overlap_eps = 0.5
);
/*!
requires
- all the requirements of the above cross_validate_object_detection_trainer() routine.
ensures
- converts all the rectangles in truth_rects into full_object_detection objects
via full_object_detection's rectangle constructor. Then invokes
cross_validate_object_detection_trainer() on the full_object_detections and
returns the results.
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment