Commit 3eb0d973 authored by Davis King's avatar Davis King

Added the cross_validate_object_detection_trainer() and test_object_detection_function()

routines.
parent 0aa89e07
......@@ -33,6 +33,7 @@
#include "svm/multiclass_tools.h"
#include "svm/cross_validate_multiclass_trainer.h"
#include "svm/cross_validate_regression_trainer.h"
#include "svm/cross_validate_object_detection_trainer.h"
#include "svm/one_vs_all_decision_function.h"
#include "svm/one_vs_all_trainer.h"
......
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_H__
#define DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_H__
#include "cross_validate_object_detection_trainer_abstract.h"
#include <vector>
#include "../matrix.h"
#include "svm.h"
#include "../geometry.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
namespace impl
{
unsigned long number_of_truth_hits (
const std::vector<rectangle>& truth_boxes,
const std::vector<rectangle>& boxes,
const double overlap_eps
)
/*!
requires
- 0 < overlap_eps <= 1
ensures
- returns the number of elements in truth_boxes which are overlapped by an
element of boxes. In this context, two boxes, A and B, overlap if and only if
the following quantity is greater than overlap_eps:
A.intersect(B).area()/(A+B).area()
- No element of boxes is allowed to account for more than one element of truth_boxes.
- The returned number is in the range [0,truth_boxes.size()]
!*/
{
if (boxes.size() == 0)
return 0;
unsigned long count = 0;
std::vector<bool> used(boxes.size(),false);
for (unsigned long i = 0; i < truth_boxes.size(); ++i)
{
unsigned long best_idx = 0;
double best_overlap = 0;
for (unsigned long j = 0; j < boxes.size(); ++j)
{
if (used[j])
continue;
const double overlap = truth_boxes[i].intersect(boxes[j]).area() / (double)(truth_boxes[i]+boxes[j]).area();
if (overlap > best_overlap)
{
best_overlap = overlap;
best_idx = j;
}
}
if (best_overlap > overlap_eps && used[best_idx] == false)
{
used[best_idx] = true;
++count;
}
}
return count;
}
}
// ----------------------------------------------------------------------------------------
template <
typename object_detector_type,
typename image_array_type
>
const matrix<double,1,2> test_object_detection_function (
const object_detector_type& detector,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects,
const double overlap_eps = 0.5
)
{
// make sure requires clause is not broken
DLIB_ASSERT( is_learning_problem(images,truth_rects) == true &&
0 < overlap_eps && overlap_eps <= 1,
"\t matrix test_object_detection_function()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t is_learning_problem(images,truth_rects): " << is_learning_problem(images,truth_rects)
<< "\n\t overlap_eps: "<< overlap_eps
);
double correct_hits = 0;
double total_hits = 0;
double total_true_targets = 0;
for (unsigned long i = 0; i < images.size(); ++i)
{
const std::vector<rectangle>& hits = detector(images[i]);
total_hits += hits.size();
correct_hits += impl::number_of_truth_hits(truth_rects[i], hits, overlap_eps);
total_true_targets += truth_rects[i].size();
}
double precision, recall;
if (total_hits == 0)
precision = 1;
else
precision = correct_hits / total_hits;
if (total_true_targets == 0)
recall = 1;
else
recall = correct_hits / total_true_targets;
matrix<double, 1, 2> res;
res = precision, recall;
return res;
}
// ----------------------------------------------------------------------------------------
namespace impl
{
template <
typename array_type
>
struct array_subset_helper
{
array_subset_helper (
const array_type& array_,
const std::vector<unsigned long>& idx_set_
) :
array(array_),
idx_set(idx_set_)
{
}
unsigned long size() const { return idx_set.size(); }
typedef typename array_type::type type;
const type& operator[] (
unsigned long idx
) const { return array[idx_set[idx]]; }
private:
const array_type& array;
const std::vector<unsigned long>& idx_set;
};
}
// ----------------------------------------------------------------------------------------
template <
typename trainer_type,
typename image_array_type
>
const matrix<double,1,2> cross_validate_object_detection_trainer (
const trainer_type& trainer,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects,
const long folds,
const double overlap_eps = 0.5
)
{
// make sure requires clause is not broken
DLIB_ASSERT( is_learning_problem(images,truth_rects) == true &&
0 < overlap_eps && overlap_eps <= 1 &&
1 < folds && folds <= images.size(),
"\t matrix cross_validate_object_detection_trainer()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t is_learning_problem(images,truth_rects): " << is_learning_problem(images,truth_rects)
<< "\n\t overlap_eps: "<< overlap_eps
<< "\n\t folds: "<< folds
);
double correct_hits = 0;
double total_hits = 0;
double total_true_targets = 0;
const long test_size = images.size()/folds;
unsigned long test_idx = 0;
for (long iter = 0; iter < folds; ++iter)
{
std::vector<unsigned long> train_idx_set;
std::vector<unsigned long> test_idx_set;
for (unsigned long i = 0; i < test_size; ++i)
test_idx_set.push_back(test_idx++);
unsigned long train_idx = test_idx%images.size();
std::vector<std::vector<rectangle> > training_rects;
for (unsigned long i = 0; i < images.size()-test_size; ++i)
{
training_rects.push_back(truth_rects[train_idx]);
train_idx_set.push_back(train_idx);
train_idx = (train_idx+1)%images.size();
}
impl::array_subset_helper<image_array_type> array_subset(images, train_idx_set);
const typename trainer_type::trained_function_type& detector = trainer.train(array_subset, training_rects);
for (unsigned long i = 0; i < test_idx_set.size(); ++i)
{
const std::vector<rectangle>& hits = detector(images[test_idx_set[i]]);
total_hits += hits.size();
correct_hits += impl::number_of_truth_hits(truth_rects[test_idx_set[i]], hits, overlap_eps);
total_true_targets += truth_rects[test_idx_set[i]].size();
}
}
double precision, recall;
if (total_hits == 0)
precision = 1;
else
precision = correct_hits / total_hits;
if (total_true_targets == 0)
recall = 1;
else
recall = correct_hits / total_true_targets;
matrix<double, 1, 2> res;
res = precision, recall;
return res;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_H__
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_ABSTRACT_H__
#ifdef DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_ABSTRACT_H__
#include <vector>
#include "../matrix.h"
#include "../geometry.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename object_detector_type,
typename image_array_type
>
const matrix<double,1,2> test_object_detection_function (
const object_detector_type& detector,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects,
const double overlap_eps = 0.5
);
/*!
requires
- is_learning_problem(images,truth_rects)
- 0 < overlap_eps <= 1
- object_detector_type == some kind of object detector function object
(e.g. object_detector)
- image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
and it must contain objects which can be accepted by detector().
ensures
- Tests the given detector against the supplied object detection problem
and returns the precision and recall. Note that the task is to predict,
for each images[i], the set of object locations given by truth_rects[i].
- In particular, returns a matrix M such that:
- M(0) == the precision of the detector object. This is a number
in the range [0,1] which measures the fraction of detector outputs
which correspond to a real target. A value of 1 means the detector
never produces any false alarms while a value of 0 means it only
produces false alarms.
- M(1) == the recall of the detector object. This is a number in the
range [0,1] which measure the fraction of targets found by the
detector. A value of 1 means the detector found all the targets
in truth_rects while a value of 0 means the detector didn't locate
any of the targets.
- The rule for deciding if a detector output, D, matches a truth rectangle,
T, is the following:
T and R match if and only if: T.intersect(R).area()/(T+R).area() > overlap_eps
!*/
// ----------------------------------------------------------------------------------------
template <
typename trainer_type,
typename image_array_type
>
const matrix<double,1,2> cross_validate_object_detection_trainer (
const trainer_type& trainer,
const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects,
const long folds,
const double overlap_eps = 0.5
);
/*!
requires
- is_learning_problem(images,truth_rects)
- 0 < overlap_eps <= 1
- 1 < folds <= images.size()
- trainer_type == some kind of object detection trainer (e.g structural_object_detection_trainer)
- image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
and it must contain objects which can be accepted by detector().
ensures
- Performs k-fold cross-validation by using the given trainer to solve an
object detection problem for the given number of folds. Each fold is tested
using the output of the trainer and a matrix summarizing the results is
returned. The matrix contains the precision and recall of the trained
detectors and is defined identically to the test_object_detection_function()
routine defined at the top of this file.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CROSS_VALIDATE_OBJECT_DETECTION_TRaINER_ABSTRACT_H__
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment