Commit a32002ea authored by Davis King's avatar Davis King

Added overloads of the object detection test tools that work with

both rectangle and full_object_detection objects.
parent 786b1dbd
......@@ -8,6 +8,7 @@
#include "../matrix.h"
#include "svm.h"
#include "../geometry.h"
#include "../image_processing/full_object_detection.h"
namespace dlib
{
......@@ -17,7 +18,7 @@ namespace dlib
namespace impl
{
inline unsigned long number_of_truth_hits (
const std::vector<rectangle>& truth_boxes,
const std::vector<full_object_detection>& truth_boxes,
const std::vector<rectangle>& boxes,
const double overlap_eps
)
......@@ -47,7 +48,7 @@ namespace dlib
if (used[j])
continue;
const double overlap = truth_boxes[i].intersect(boxes[j]).area() / (double)(truth_boxes[i]+boxes[j]).area();
const double overlap = truth_boxes[i].rect.intersect(boxes[j]).area() / (double)(truth_boxes[i].rect+boxes[j]).area();
if (overlap > best_overlap)
{
best_overlap = overlap;
......@@ -75,7 +76,7 @@ namespace dlib
const matrix<double,1,2> test_object_detection_function (
object_detector_type& detector,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects,
const std::vector<std::vector<full_object_detection> >& truth_rects,
const double overlap_eps = 0.5
)
{
......@@ -121,6 +122,30 @@ namespace dlib
return res;
}
template <
typename object_detector_type,
typename image_array_type
>
const matrix<double,1,2> test_object_detection_function (
object_detector_type& detector,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects,
const double overlap_eps = 0.5
)
{
// convert into a list of regular rectangles.
std::vector<std::vector<full_object_detection> > rects(truth_rects.size());
for (unsigned long i = 0; i < truth_rects.size(); ++i)
{
for (unsigned long j = 0; j < truth_rects[i].size(); ++j)
{
rects[i].push_back(full_object_detection(truth_rects[i][j]));
}
}
return test_object_detection_function(detector, images, rects, overlap_eps);
}
// ----------------------------------------------------------------------------------------
namespace impl
......@@ -163,18 +188,18 @@ namespace dlib
const matrix<double,1,2> cross_validate_object_detection_trainer (
const trainer_type& trainer,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects,
const std::vector<std::vector<full_object_detection> >& truth_object_detections,
const long folds,
const double overlap_eps = 0.5
)
{
// make sure requires clause is not broken
DLIB_ASSERT( is_learning_problem(images,truth_rects) == true &&
DLIB_ASSERT( is_learning_problem(images,truth_object_detections) == true &&
0 < overlap_eps && overlap_eps <= 1 &&
1 < folds && folds <= static_cast<long>(images.size()),
"\t matrix cross_validate_object_detection_trainer()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t is_learning_problem(images,truth_rects): " << is_learning_problem(images,truth_rects)
<< "\n\t is_learning_problem(images,truth_object_detections): " << is_learning_problem(images,truth_object_detections)
<< "\n\t overlap_eps: "<< overlap_eps
<< "\n\t folds: "<< folds
);
......@@ -195,10 +220,10 @@ namespace dlib
test_idx_set.push_back(test_idx++);
unsigned long train_idx = test_idx%images.size();
std::vector<std::vector<rectangle> > training_rects;
std::vector<std::vector<full_object_detection> > training_rects;
for (unsigned long i = 0; i < images.size()-test_size; ++i)
{
training_rects.push_back(truth_rects[train_idx]);
training_rects.push_back(truth_object_detections[train_idx]);
train_idx_set.push_back(train_idx);
train_idx = (train_idx+1)%images.size();
}
......@@ -211,8 +236,8 @@ namespace dlib
const std::vector<rectangle>& hits = detector(images[test_idx_set[i]]);
total_hits += hits.size();
correct_hits += impl::number_of_truth_hits(truth_rects[test_idx_set[i]], hits, overlap_eps);
total_true_targets += truth_rects[test_idx_set[i]].size();
correct_hits += impl::number_of_truth_hits(truth_object_detections[test_idx_set[i]], hits, overlap_eps);
total_true_targets += truth_object_detections[test_idx_set[i]].size();
}
}
......@@ -236,6 +261,31 @@ namespace dlib
return res;
}
template <
typename trainer_type,
typename image_array_type
>
const matrix<double,1,2> cross_validate_object_detection_trainer (
const trainer_type& trainer,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_object_detections,
const long folds,
const double overlap_eps = 0.5
)
{
// convert into a list of regular rectangles.
std::vector<std::vector<full_object_detection> > dets(truth_object_detections.size());
for (unsigned long i = 0; i < truth_object_detections.size(); ++i)
{
for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j)
{
dets[i].push_back(full_object_detection(truth_object_detections[i][j]));
}
}
return cross_validate_object_detection_trainer(trainer, images, dets, folds, overlap_eps);
}
// ----------------------------------------------------------------------------------------
}
......
......@@ -6,6 +6,7 @@
#include <vector>
#include "../matrix.h"
#include "../geometry.h"
#include "../image_processing/full_object_detection_abstract.h"
namespace dlib
{
......@@ -19,7 +20,7 @@ namespace dlib
const matrix<double,1,2> test_object_detection_function (
object_detector_type& detector,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects,
const std::vector<std::vector<full_object_detection> >& truth_rects,
const double overlap_eps = 0.5
);
/*!
......@@ -50,6 +51,26 @@ namespace dlib
T and R match if and only if: T.intersect(R).area()/(T+R).area() > overlap_eps
!*/
template <
typename object_detector_type,
typename image_array_type
>
const matrix<double,1,2> test_object_detection_function (
object_detector_type& detector,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects,
const double overlap_eps = 0.5
);
/*!
requires
- all the requirements of the above test_object_detection_function() routine.
ensures
- converts all the rectangles in truth_rects into full_object_detection objects
via full_object_detection's rectangle constructor. Then invokes
test_object_detection_function() on the full_object_detections and returns
the results.
!*/
// ----------------------------------------------------------------------------------------
template <
......@@ -59,7 +80,7 @@ namespace dlib
const matrix<double,1,2> cross_validate_object_detection_trainer (
const trainer_type& trainer,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects,
const std::vector<std::vector<full_object_detection> >& truth_rects,
const long folds,
const double overlap_eps = 0.5
);
......@@ -71,6 +92,7 @@ namespace dlib
- trainer_type == some kind of object detection trainer (e.g structural_object_detection_trainer)
- image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
and it must contain objects which can be accepted by detector().
- it is legal to call trainer.train(images, truth_rects)
ensures
- Performs k-fold cross-validation by using the given trainer to solve an
object detection problem for the given number of folds. Each fold is tested
......@@ -80,6 +102,26 @@ namespace dlib
routine defined at the top of this file.
!*/
template <
typename trainer_type,
typename image_array_type
>
const matrix<double,1,2> cross_validate_object_detection_trainer (
const trainer_type& trainer,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects,
const long folds,
const double overlap_eps = 0.5
);
/*!
requires
- all the requirements of the above cross_validate_object_detection_trainer() routine.
ensures
- converts all the rectangles in truth_rects into full_object_detection objects
via full_object_detection's rectangle constructor. Then invokes
cross_validate_object_detection_trainer() on the full_object_detections and
returns the results.
!*/
// ----------------------------------------------------------------------------------------
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment