Commit a6b44e0e authored by Davis King's avatar Davis King

Cleaned up the full_object_detection's interface and improved some comments

here and there.
parent 3bcab68a
// Copyright (C) 2011 Davis E. King (davis@dlib.net) // Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license. // License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_FULL_OBJECT_DeTECTION_H__ #ifndef DLIB_FULL_OBJECT_DeTECTION_H__
#define DLIB_FULL_OBJECT_DeTECTION_H__ #define DLIB_FULL_OBJECT_DeTECTION_H__
...@@ -12,17 +12,18 @@ namespace dlib ...@@ -12,17 +12,18 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
const static point MOVABLE_PART_NOT_PRESENT(0x7FFFFFFF, const static point OBJECT_PART_NOT_PRESENT(0x7FFFFFFF,
0x7FFFFFFF); 0x7FFFFFFF);
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
struct full_object_detection class full_object_detection
{ {
public:
full_object_detection( full_object_detection(
const rectangle& rect_, const rectangle& rect_,
const std::vector<point>& movable_parts_ const std::vector<point>& parts_
) : rect(rect_), movable_parts(movable_parts_) {} ) : rect(rect_), parts(parts_) {}
full_object_detection(){} full_object_detection(){}
...@@ -30,8 +31,27 @@ namespace dlib ...@@ -30,8 +31,27 @@ namespace dlib
const rectangle& rect_ const rectangle& rect_
) : rect(rect_) {} ) : rect(rect_) {}
const rectangle& get_rect() const { return rect; }
unsigned long num_parts() const { return parts.size(); }
const point& part(
unsigned long idx
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(idx < num_parts(),
"\t point full_object_detection::part()"
<< "\n\t Invalid inputs were given to this function "
<< "\n\t idx: " << idx
<< "\n\t num_parts(): " << num_parts()
<< "\n\t this: " << this
);
return parts[idx];
}
private:
rectangle rect; rectangle rect;
std::vector<point> movable_parts; // it should always be the case that rect.contains(movable_parts[i]) == true std::vector<point> parts;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -40,10 +60,10 @@ namespace dlib ...@@ -40,10 +60,10 @@ namespace dlib
const full_object_detection& obj const full_object_detection& obj
) )
{ {
for (unsigned long i = 0; i < obj.movable_parts.size(); ++i) for (unsigned long i = 0; i < obj.num_parts(); ++i)
{ {
if (obj.rect.contains(obj.movable_parts[i]) == false && if (obj.get_rect().contains(obj.part(i)) == false &&
obj.movable_parts[i] != MOVABLE_PART_NOT_PRESENT) obj.part(i) != OBJECT_PART_NOT_PRESENT)
return false; return false;
} }
return true; return true;
......
// Copyright (C) 2011 Davis E. King (davis@dlib.net) // Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license. // License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_FULL_OBJECT_DeTECTION_ABSTRACT_H__ #undef DLIB_FULL_OBJECT_DeTECTION_ABSTRACT_H__
#ifdef DLIB_FULL_OBJECT_DeTECTION_ABSTRACT_H__ #ifdef DLIB_FULL_OBJECT_DeTECTION_ABSTRACT_H__
...@@ -11,26 +11,92 @@ namespace dlib ...@@ -11,26 +11,92 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
const static point MOVABLE_PART_NOT_PRESENT(0x7FFFFFFF, const static point OBJECT_PART_NOT_PRESENT(0x7FFFFFFF,
0x7FFFFFFF); 0x7FFFFFFF);
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
struct full_object_detection class full_object_detection
{ {
/*!
WHAT THIS OBJECT REPRESENTS
This object represents the location of an object in an image along with the
positions of each of its constituent parts.
!*/
public:
full_object_detection( full_object_detection(
const rectangle& rect_, const rectangle& rect,
const std::vector<point>& movable_parts_ const std::vector<point>& parts
) : rect(rect_), movable_parts(movable_parts_) {} );
/*!
ensures
- #get_rect() == rect
- #num_parts() == parts.size()
- for all valid i:
- part(i) == parts[i]
!*/
full_object_detection(
);
/*!
ensures
- #get_rect().is_empty() == true
- #num_parts() == 0
!*/
explicit full_object_detection( explicit full_object_detection(
const rectangle& rect_ const rectangle& rect
) : rect(rect_) {} );
/*!
ensures
- #get_rect() == rect
- #num_parts() == 0
!*/
const rectangle& get_rect(
) const;
/*!
ensures
- returns the rectangle that indicates where this object is. In general,
this should be the bounding box for the object.
!*/
rectangle rect; unsigned long num_parts(
std::vector<point> movable_parts; // it should always be the case that rect.contains(movable_parts[i]) == true ) const;
/*!
ensures
- returns the number of parts in this object.
!*/
const point& part(
unsigned long idx
) const;
/*!
requires
- idx < num_parts()
ensures
- returns the location of the center of the idx-th part of this object.
Note that it is valid for a part to be "not present". This is indicated
when the return value of part() is equal to OBJECT_PART_NOT_PRESENT.
This is useful for modeling object parts that are not always observed.
!*/
}; };
// ----------------------------------------------------------------------------------------
bool all_parts_in_rect (
const full_object_detection& obj
);
/*!
ensures
- returns true if all the parts in obj are contained within obj.get_rect().
That is, returns true if and only if, for all valid i, the following is
always true:
obj.get_rect().contains(obj.parts()[i]) == true || obj.parts()[i] == OBJECT_PART_NOT_PRESENT
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
...@@ -838,7 +838,6 @@ namespace dlib ...@@ -838,7 +838,6 @@ namespace dlib
const feature_vector_type& w const feature_vector_type& w
) const ) const
{ {
full_object_detection obj(rect);
// fill in movable part positions. // fill in movable part positions.
rectangle mapped_rect; rectangle mapped_rect;
...@@ -855,6 +854,8 @@ namespace dlib ...@@ -855,6 +854,8 @@ namespace dlib
// convert into feature space. // convert into feature space.
object_box = object_box.intersect(get_rect(feats[best_level])); object_box = object_box.intersect(get_rect(feats[best_level]));
std::vector<point> movable_parts;
movable_parts.reserve(get_num_movable_components_per_detection_template());
for (unsigned long i = 0; i < get_num_movable_components_per_detection_template(); ++i) for (unsigned long i = 0; i < get_num_movable_components_per_detection_template(); ++i)
{ {
// make the saliency_image for the ith movable part. // make the saliency_image for the ith movable part.
...@@ -912,7 +913,7 @@ namespace dlib ...@@ -912,7 +913,7 @@ namespace dlib
if (max_val <= 0) if (max_val <= 0)
{ {
max_loc = MOVABLE_PART_NOT_PRESENT; max_loc = OBJECT_PART_NOT_PRESENT;
} }
else else
{ {
...@@ -923,13 +924,13 @@ namespace dlib ...@@ -923,13 +924,13 @@ namespace dlib
// now convert from feature space to image space. // now convert from feature space to image space.
max_loc = feats[best_level].feat_to_image_space(max_loc); max_loc = feats[best_level].feat_to_image_space(max_loc);
max_loc = pyr.point_up(max_loc, best_level); max_loc = pyr.point_up(max_loc, best_level);
max_loc = nearest_point(obj.rect, max_loc); max_loc = nearest_point(rect, max_loc);
} }
obj.movable_parts.push_back(max_loc); movable_parts.push_back(max_loc);
} }
return obj; return full_object_detection(rect, movable_parts);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -948,7 +949,7 @@ namespace dlib ...@@ -948,7 +949,7 @@ namespace dlib
DLIB_ASSERT(get_num_detection_templates() > 0 && DLIB_ASSERT(get_num_detection_templates() > 0 &&
is_loaded_with_image() && is_loaded_with_image() &&
psi.size() >= get_num_dimensions() && psi.size() >= get_num_dimensions() &&
obj.movable_parts.size() == get_num_movable_components_per_detection_template(), obj.num_parts() == get_num_movable_components_per_detection_template(),
"\t void scan_image_pyramid::get_feature_vector()" "\t void scan_image_pyramid::get_feature_vector()"
<< "\n\t Invalid inputs were given to this function " << "\n\t Invalid inputs were given to this function "
<< "\n\t get_num_detection_templates(): " << get_num_detection_templates() << "\n\t get_num_detection_templates(): " << get_num_detection_templates()
...@@ -956,35 +957,34 @@ namespace dlib ...@@ -956,35 +957,34 @@ namespace dlib
<< "\n\t psi.size(): " << psi.size() << "\n\t psi.size(): " << psi.size()
<< "\n\t get_num_dimensions(): " << get_num_dimensions() << "\n\t get_num_dimensions(): " << get_num_dimensions()
<< "\n\t get_num_movable_components_per_detection_template(): " << get_num_movable_components_per_detection_template() << "\n\t get_num_movable_components_per_detection_template(): " << get_num_movable_components_per_detection_template()
<< "\n\t obj.movable_parts.size(): " << obj.movable_parts.size() << "\n\t obj.num_parts(): " << obj.num_parts()
<< "\n\t this: " << this << "\n\t this: " << this
); );
DLIB_ASSERT(all_parts_in_rect(obj), DLIB_ASSERT(all_parts_in_rect(obj),
"\t void scan_image_pyramid::get_feature_vector()" "\t void scan_image_pyramid::get_feature_vector()"
<< "\n\t Invalid inputs were given to this function " << "\n\t Invalid inputs were given to this function "
<< "\n\t obj.rect: " << obj.rect << "\n\t obj.get_rect(): " << obj.get_rect()
<< "\n\t this: " << this << "\n\t this: " << this
); );
const rectangle rect = obj.rect;
rectangle mapped_rect; rectangle mapped_rect;
detection_template best_template; detection_template best_template;
unsigned long best_level; unsigned long best_level;
rectangle object_box; rectangle object_box;
get_mapped_rect_and_metadata (feats.size(), rect, mapped_rect, best_template, object_box, best_level); get_mapped_rect_and_metadata (feats.size(), obj.get_rect(), mapped_rect, best_template, object_box, best_level);
Pyramid_type pyr; Pyramid_type pyr;
// put the movable rects at the places indicated by obj. // put the movable rects at the places indicated by obj.
std::vector<rectangle> rects = best_template.rects; std::vector<rectangle> rects = best_template.rects;
for (unsigned long i = 0; i < obj.movable_parts.size(); ++i) for (unsigned long i = 0; i < obj.num_parts(); ++i)
{ {
if (obj.movable_parts[i] != MOVABLE_PART_NOT_PRESENT) if (obj.part(i) != OBJECT_PART_NOT_PRESENT)
{ {
// map from the original image to scaled feature space. // map from the original image to scaled feature space.
point loc = feats[best_level].image_to_feat_space(pyr.point_down(obj.movable_parts[i], best_level)); point loc = feats[best_level].image_to_feat_space(pyr.point_down(obj.part(i), best_level));
// Make sure the movable part always stays within the object_box. // Make sure the movable part always stays within the object_box.
// Otherwise it would be at a place that the detect() function can never // Otherwise it would be at a place that the detect() function can never
// look. // look.
......
...@@ -398,7 +398,7 @@ namespace dlib ...@@ -398,7 +398,7 @@ namespace dlib
/*! /*!
requires requires
- all_parts_in_rect(obj) == true - all_parts_in_rect(obj) == true
- obj.movable_parts.size() == get_num_movable_components_per_detection_template() - obj.num_parts() == get_num_movable_components_per_detection_template()
- is_loaded_with_image() == true - is_loaded_with_image() == true
- get_num_detection_templates() > 0 - get_num_detection_templates() > 0
- psi.size() >= get_num_dimensions() - psi.size() >= get_num_dimensions()
...@@ -410,11 +410,11 @@ namespace dlib ...@@ -410,11 +410,11 @@ namespace dlib
detect() into the needed full_object_detection. detect() into the needed full_object_detection.
- Since scan_image_pyramid is a sliding window classifier system, not all - Since scan_image_pyramid is a sliding window classifier system, not all
possible rectangles can be output by detect(). So in the case where possible rectangles can be output by detect(). So in the case where
obj.rect could not arise from a call to detect(), this function will map obj.get_rect() could not arise from a call to detect(), this function
obj.rect to the nearest possible object box and then add the feature will map obj.get_rect() to the nearest possible object box and then add
vector for the mapped rectangle into #psi. the feature vector for the mapped rectangle into #psi.
- get_best_matching_rect(obj.rect) == the rectangle obj.rect gets mapped to - get_best_matching_rect(obj.get_rect()) == the rectangle obj.get_rect()
for feature extraction. gets mapped to for feature extraction.
!*/ !*/
full_object_detection get_full_object_detection ( full_object_detection get_full_object_detection (
...@@ -436,9 +436,10 @@ namespace dlib ...@@ -436,9 +436,10 @@ namespace dlib
Then the corresponding fully populated full_object_detection will be Then the corresponding fully populated full_object_detection will be
returned. returned.
- returns a full_object_detection, OBJ, such that: - returns a full_object_detection, OBJ, such that:
- OBJ.rect == rect - OBJ.get_rect() == rect
- OBJ.movable_parts.size() == get_num_movable_components_per_detection_template() - OBJ.num_parts() == get_num_movable_components_per_detection_template()
- OBJ.movable_parts == the locations of the movable parts inside this detection. - OBJ.part(i) == the location of the i-th movable part inside this detection,
or OBJECT_PART_NOT_PRESENT if the part was not found.
!*/ !*/
}; };
......
...@@ -48,7 +48,7 @@ namespace dlib ...@@ -48,7 +48,7 @@ namespace dlib
if (used[j]) if (used[j])
continue; continue;
const double overlap = truth_boxes[i].rect.intersect(boxes[j]).area() / (double)(truth_boxes[i].rect+boxes[j]).area(); const double overlap = truth_boxes[i].get_rect().intersect(boxes[j]).area() / (double)(truth_boxes[i].get_rect()+boxes[j]).area();
if (overlap > best_overlap) if (overlap > best_overlap)
{ {
best_overlap = overlap; best_overlap = overlap;
...@@ -76,16 +76,16 @@ namespace dlib ...@@ -76,16 +76,16 @@ namespace dlib
const matrix<double,1,2> test_object_detection_function ( const matrix<double,1,2> test_object_detection_function (
object_detector_type& detector, object_detector_type& detector,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<full_object_detection> >& truth_rects, const std::vector<std::vector<full_object_detection> >& truth_dets,
const double overlap_eps = 0.5 const double overlap_eps = 0.5
) )
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT( is_learning_problem(images,truth_rects) == true && DLIB_ASSERT( is_learning_problem(images,truth_dets) == true &&
0 < overlap_eps && overlap_eps <= 1, 0 < overlap_eps && overlap_eps <= 1,
"\t matrix test_object_detection_function()" "\t matrix test_object_detection_function()"
<< "\n\t invalid inputs were given to this function" << "\n\t invalid inputs were given to this function"
<< "\n\t is_learning_problem(images,truth_rects): " << is_learning_problem(images,truth_rects) << "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets)
<< "\n\t overlap_eps: "<< overlap_eps << "\n\t overlap_eps: "<< overlap_eps
); );
...@@ -100,8 +100,8 @@ namespace dlib ...@@ -100,8 +100,8 @@ namespace dlib
const std::vector<rectangle>& hits = detector(images[i]); const std::vector<rectangle>& hits = detector(images[i]);
total_hits += hits.size(); total_hits += hits.size();
correct_hits += impl::number_of_truth_hits(truth_rects[i], hits, overlap_eps); correct_hits += impl::number_of_truth_hits(truth_dets[i], hits, overlap_eps);
total_true_targets += truth_rects[i].size(); total_true_targets += truth_dets[i].size();
} }
...@@ -129,17 +129,17 @@ namespace dlib ...@@ -129,17 +129,17 @@ namespace dlib
const matrix<double,1,2> test_object_detection_function ( const matrix<double,1,2> test_object_detection_function (
object_detector_type& detector, object_detector_type& detector,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects, const std::vector<std::vector<rectangle> >& truth_dets,
const double overlap_eps = 0.5 const double overlap_eps = 0.5
) )
{ {
// convert into a list of regular rectangles. // convert into a list of regular rectangles.
std::vector<std::vector<full_object_detection> > rects(truth_rects.size()); std::vector<std::vector<full_object_detection> > rects(truth_dets.size());
for (unsigned long i = 0; i < truth_rects.size(); ++i) for (unsigned long i = 0; i < truth_dets.size(); ++i)
{ {
for (unsigned long j = 0; j < truth_rects[i].size(); ++j) for (unsigned long j = 0; j < truth_dets[i].size(); ++j)
{ {
rects[i].push_back(full_object_detection(truth_rects[i][j])); rects[i].push_back(full_object_detection(truth_dets[i][j]));
} }
} }
...@@ -188,18 +188,18 @@ namespace dlib ...@@ -188,18 +188,18 @@ namespace dlib
const matrix<double,1,2> cross_validate_object_detection_trainer ( const matrix<double,1,2> cross_validate_object_detection_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<full_object_detection> >& truth_object_detections, const std::vector<std::vector<full_object_detection> >& truth_dets,
const long folds, const long folds,
const double overlap_eps = 0.5 const double overlap_eps = 0.5
) )
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT( is_learning_problem(images,truth_object_detections) == true && DLIB_ASSERT( is_learning_problem(images,truth_dets) == true &&
0 < overlap_eps && overlap_eps <= 1 && 0 < overlap_eps && overlap_eps <= 1 &&
1 < folds && folds <= static_cast<long>(images.size()), 1 < folds && folds <= static_cast<long>(images.size()),
"\t matrix cross_validate_object_detection_trainer()" "\t matrix cross_validate_object_detection_trainer()"
<< "\n\t invalid inputs were given to this function" << "\n\t invalid inputs were given to this function"
<< "\n\t is_learning_problem(images,truth_object_detections): " << is_learning_problem(images,truth_object_detections) << "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets)
<< "\n\t overlap_eps: "<< overlap_eps << "\n\t overlap_eps: "<< overlap_eps
<< "\n\t folds: "<< folds << "\n\t folds: "<< folds
); );
...@@ -223,7 +223,7 @@ namespace dlib ...@@ -223,7 +223,7 @@ namespace dlib
std::vector<std::vector<full_object_detection> > training_rects; std::vector<std::vector<full_object_detection> > training_rects;
for (unsigned long i = 0; i < images.size()-test_size; ++i) for (unsigned long i = 0; i < images.size()-test_size; ++i)
{ {
training_rects.push_back(truth_object_detections[train_idx]); training_rects.push_back(truth_dets[train_idx]);
train_idx_set.push_back(train_idx); train_idx_set.push_back(train_idx);
train_idx = (train_idx+1)%images.size(); train_idx = (train_idx+1)%images.size();
} }
...@@ -236,8 +236,8 @@ namespace dlib ...@@ -236,8 +236,8 @@ namespace dlib
const std::vector<rectangle>& hits = detector(images[test_idx_set[i]]); const std::vector<rectangle>& hits = detector(images[test_idx_set[i]]);
total_hits += hits.size(); total_hits += hits.size();
correct_hits += impl::number_of_truth_hits(truth_object_detections[test_idx_set[i]], hits, overlap_eps); correct_hits += impl::number_of_truth_hits(truth_dets[test_idx_set[i]], hits, overlap_eps);
total_true_targets += truth_object_detections[test_idx_set[i]].size(); total_true_targets += truth_dets[test_idx_set[i]].size();
} }
} }
...@@ -268,18 +268,18 @@ namespace dlib ...@@ -268,18 +268,18 @@ namespace dlib
const matrix<double,1,2> cross_validate_object_detection_trainer ( const matrix<double,1,2> cross_validate_object_detection_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_object_detections, const std::vector<std::vector<rectangle> >& truth_dets,
const long folds, const long folds,
const double overlap_eps = 0.5 const double overlap_eps = 0.5
) )
{ {
// convert into a list of regular rectangles. // convert into a list of regular rectangles.
std::vector<std::vector<full_object_detection> > dets(truth_object_detections.size()); std::vector<std::vector<full_object_detection> > dets(truth_dets.size());
for (unsigned long i = 0; i < truth_object_detections.size(); ++i) for (unsigned long i = 0; i < truth_dets.size(); ++i)
{ {
for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j) for (unsigned long j = 0; j < truth_dets[i].size(); ++j)
{ {
dets[i].push_back(full_object_detection(truth_object_detections[i][j])); dets[i].push_back(full_object_detection(truth_dets[i][j]));
} }
} }
......
...@@ -20,12 +20,12 @@ namespace dlib ...@@ -20,12 +20,12 @@ namespace dlib
const matrix<double,1,2> test_object_detection_function ( const matrix<double,1,2> test_object_detection_function (
object_detector_type& detector, object_detector_type& detector,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<full_object_detection> >& truth_rects, const std::vector<std::vector<full_object_detection> >& truth_dets,
const double overlap_eps = 0.5 const double overlap_eps = 0.5
); );
/*! /*!
requires requires
- is_learning_problem(images,truth_rects) - is_learning_problem(images,truth_dets)
- 0 < overlap_eps <= 1 - 0 < overlap_eps <= 1
- object_detector_type == some kind of object detector function object - object_detector_type == some kind of object detector function object
(e.g. object_detector) (e.g. object_detector)
...@@ -34,7 +34,7 @@ namespace dlib ...@@ -34,7 +34,7 @@ namespace dlib
ensures ensures
- Tests the given detector against the supplied object detection problem - Tests the given detector against the supplied object detection problem
and returns the precision and recall. Note that the task is to predict, and returns the precision and recall. Note that the task is to predict,
for each images[i], the set of object locations given by truth_rects[i]. for each images[i], the set of object locations given by truth_dets[i].
- In particular, returns a matrix M such that: - In particular, returns a matrix M such that:
- M(0) == the precision of the detector object. This is a number - M(0) == the precision of the detector object. This is a number
in the range [0,1] which measures the fraction of detector outputs in the range [0,1] which measures the fraction of detector outputs
...@@ -44,7 +44,7 @@ namespace dlib ...@@ -44,7 +44,7 @@ namespace dlib
- M(1) == the recall of the detector object. This is a number in the - M(1) == the recall of the detector object. This is a number in the
range [0,1] which measure the fraction of targets found by the range [0,1] which measure the fraction of targets found by the
detector. A value of 1 means the detector found all the targets detector. A value of 1 means the detector found all the targets
in truth_rects while a value of 0 means the detector didn't locate in truth_dets while a value of 0 means the detector didn't locate
any of the targets. any of the targets.
- The rule for deciding if a detector output, D, matches a truth rectangle, - The rule for deciding if a detector output, D, matches a truth rectangle,
T, is the following: T, is the following:
...@@ -58,14 +58,14 @@ namespace dlib ...@@ -58,14 +58,14 @@ namespace dlib
const matrix<double,1,2> test_object_detection_function ( const matrix<double,1,2> test_object_detection_function (
object_detector_type& detector, object_detector_type& detector,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects, const std::vector<std::vector<rectangle> >& truth_dets,
const double overlap_eps = 0.5 const double overlap_eps = 0.5
); );
/*! /*!
requires requires
- all the requirements of the above test_object_detection_function() routine. - all the requirements of the above test_object_detection_function() routine.
ensures ensures
- converts all the rectangles in truth_rects into full_object_detection objects - converts all the rectangles in truth_dets into full_object_detection objects
via full_object_detection's rectangle constructor. Then invokes via full_object_detection's rectangle constructor. Then invokes
test_object_detection_function() on the full_object_detections and returns test_object_detection_function() on the full_object_detections and returns
the results. the results.
...@@ -80,19 +80,19 @@ namespace dlib ...@@ -80,19 +80,19 @@ namespace dlib
const matrix<double,1,2> cross_validate_object_detection_trainer ( const matrix<double,1,2> cross_validate_object_detection_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<full_object_detection> >& truth_rects, const std::vector<std::vector<full_object_detection> >& truth_dets,
const long folds, const long folds,
const double overlap_eps = 0.5 const double overlap_eps = 0.5
); );
/*! /*!
requires requires
- is_learning_problem(images,truth_rects) - is_learning_problem(images,truth_dets)
- 0 < overlap_eps <= 1 - 0 < overlap_eps <= 1
- 1 < folds <= images.size() - 1 < folds <= images.size()
- trainer_type == some kind of object detection trainer (e.g structural_object_detection_trainer) - trainer_type == some kind of object detection trainer (e.g structural_object_detection_trainer)
- image_array_type must be an implementation of dlib/array/array_kernel_abstract.h - image_array_type must be an implementation of dlib/array/array_kernel_abstract.h
and it must contain objects which can be accepted by detector(). and it must contain objects which can be accepted by detector().
- it is legal to call trainer.train(images, truth_rects) - it is legal to call trainer.train(images, truth_dets)
ensures ensures
- Performs k-fold cross-validation by using the given trainer to solve an - Performs k-fold cross-validation by using the given trainer to solve an
object detection problem for the given number of folds. Each fold is tested object detection problem for the given number of folds. Each fold is tested
...@@ -109,7 +109,7 @@ namespace dlib ...@@ -109,7 +109,7 @@ namespace dlib
const matrix<double,1,2> cross_validate_object_detection_trainer ( const matrix<double,1,2> cross_validate_object_detection_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const image_array_type& images, const image_array_type& images,
const std::vector<std::vector<rectangle> >& truth_rects, const std::vector<std::vector<rectangle> >& truth_dets,
const long folds, const long folds,
const double overlap_eps = 0.5 const double overlap_eps = 0.5
); );
...@@ -117,7 +117,7 @@ namespace dlib ...@@ -117,7 +117,7 @@ namespace dlib
requires requires
- all the requirements of the above cross_validate_object_detection_trainer() routine. - all the requirements of the above cross_validate_object_detection_trainer() routine.
ensures ensures
- converts all the rectangles in truth_rects into full_object_detection objects - converts all the rectangles in truth_dets into full_object_detection objects
via full_object_detection's rectangle constructor. Then invokes via full_object_detection's rectangle constructor. Then invokes
cross_validate_object_detection_trainer() on the full_object_detections and cross_validate_object_detection_trainer() on the full_object_detections and
returns the results. returns the results.
......
...@@ -262,12 +262,12 @@ namespace dlib ...@@ -262,12 +262,12 @@ namespace dlib
{ {
for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j) for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j)
{ {
DLIB_ASSERT(truth_object_detections[i][j].movable_parts.size() == get_scanner().get_num_movable_components_per_detection_template() && DLIB_ASSERT(truth_object_detections[i][j].num_parts() == get_scanner().get_num_movable_components_per_detection_template() &&
all_parts_in_rect(truth_object_detections[i][j]) == true, all_parts_in_rect(truth_object_detections[i][j]) == true,
"\t trained_function_type structural_object_detection_trainer::train()" "\t trained_function_type structural_object_detection_trainer::train()"
<< "\n\t invalid inputs were given to this function" << "\n\t invalid inputs were given to this function"
<< "\n\t truth_object_detections["<<i<<"]["<<j<<"].movable_parts.size(): " << << "\n\t truth_object_detections["<<i<<"]["<<j<<"].num_parts(): " <<
truth_object_detections[i][j].movable_parts.size() truth_object_detections[i][j].num_parts()
<< "\n\t get_scanner().get_num_movable_components_per_detection_template(): " << << "\n\t get_scanner().get_num_movable_components_per_detection_template(): " <<
get_scanner().get_num_movable_components_per_detection_template() get_scanner().get_num_movable_components_per_detection_template()
<< "\n\t all_parts_in_rect(truth_object_detections["<<i<<"]["<<j<<"]): " << all_parts_in_rect(truth_object_detections[i][j]) << "\n\t all_parts_in_rect(truth_object_detections["<<i<<"]["<<j<<"]): " << all_parts_in_rect(truth_object_detections[i][j])
...@@ -286,7 +286,7 @@ namespace dlib ...@@ -286,7 +286,7 @@ namespace dlib
mapped_rects[i].resize(truth_object_detections[i].size()); mapped_rects[i].resize(truth_object_detections[i].size());
for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j) for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j)
{ {
mapped_rects[i][j] = scanner.get_best_matching_rect(truth_object_detections[i][j].rect); mapped_rects[i][j] = scanner.get_best_matching_rect(truth_object_detections[i][j].get_rect());
} }
} }
......
...@@ -295,7 +295,7 @@ namespace dlib ...@@ -295,7 +295,7 @@ namespace dlib
- it must be valid to pass images[0] into the image_scanner_type::load() method. - it must be valid to pass images[0] into the image_scanner_type::load() method.
(also, image_array_type must be an implementation of dlib/array/array_kernel_abstract.h) (also, image_array_type must be an implementation of dlib/array/array_kernel_abstract.h)
- for all valid i, j: - for all valid i, j:
- truth_object_detections[i][j].movable_parts.size() == get_scanner().get_num_movable_components_per_detection_template() - truth_object_detections[i][j].num_parts() == get_scanner().get_num_movable_components_per_detection_template()
- all_parts_in_rect(truth_object_detections[i][j]) == true - all_parts_in_rect(truth_object_detections[i][j]) == true
ensures ensures
- Uses the structural_svm_object_detection_problem to train an object_detector - Uses the structural_svm_object_detection_problem to train an object_detector
......
...@@ -63,11 +63,11 @@ namespace dlib ...@@ -63,11 +63,11 @@ namespace dlib
{ {
for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j) for (unsigned long j = 0; j < truth_object_detections[i].size(); ++j)
{ {
DLIB_ASSERT(truth_object_detections[i][j].movable_parts.size() == scanner.get_num_movable_components_per_detection_template(), DLIB_ASSERT(truth_object_detections[i][j].num_parts() == scanner.get_num_movable_components_per_detection_template(),
"\t trained_function_type structural_object_detection_trainer::train()" "\t trained_function_type structural_object_detection_trainer::train()"
<< "\n\t invalid inputs were given to this function" << "\n\t invalid inputs were given to this function"
<< "\n\t truth_object_detections["<<i<<"]["<<j<<"].movable_parts.size(): " << << "\n\t truth_object_detections["<<i<<"]["<<j<<"].num_parts(): " <<
truth_object_detections[i][j].movable_parts.size() truth_object_detections[i][j].num_parts()
<< "\n\t scanner.get_num_movable_components_per_detection_template(): " << << "\n\t scanner.get_num_movable_components_per_detection_template(): " <<
scanner.get_num_movable_components_per_detection_template() scanner.get_num_movable_components_per_detection_template()
<< "\n\t all_parts_in_rect(truth_object_detections["<<i<<"]["<<j<<"]): " << all_parts_in_rect(truth_object_detections[i][j]) << "\n\t all_parts_in_rect(truth_object_detections["<<i<<"]["<<j<<"]): " << all_parts_in_rect(truth_object_detections[i][j])
...@@ -180,7 +180,7 @@ namespace dlib ...@@ -180,7 +180,7 @@ namespace dlib
psi = 0; psi = 0;
for (unsigned long i = 0; i < truth_object_detections[idx].size(); ++i) for (unsigned long i = 0; i < truth_object_detections[idx].size(); ++i)
{ {
mapped_rects.push_back(scanner.get_best_matching_rect(truth_object_detections[idx][i].rect)); mapped_rects.push_back(scanner.get_best_matching_rect(truth_object_detections[idx][i].get_rect()));
scanner.get_feature_vector(truth_object_detections[idx][i], psi); scanner.get_feature_vector(truth_object_detections[idx][i], psi);
} }
psi(scanner.get_num_dimensions()) = -1.0*truth_object_detections[idx].size(); psi(scanner.get_num_dimensions()) = -1.0*truth_object_detections[idx].size();
...@@ -225,8 +225,8 @@ namespace dlib ...@@ -225,8 +225,8 @@ namespace dlib
// truth rectangles. // truth rectangles.
for (unsigned long i = 0; i < mapped_rects.size(); ++i) for (unsigned long i = 0; i < mapped_rects.size(); ++i)
{ {
const double area = (truth_object_detections[idx][i].rect.intersect(mapped_rects[i])).area(); const double area = (truth_object_detections[idx][i].get_rect().intersect(mapped_rects[i])).area();
const double total_area = (truth_object_detections[idx][i].rect + mapped_rects[i]).area(); const double total_area = (truth_object_detections[idx][i].get_rect() + mapped_rects[i]).area();
if (area/total_area <= match_eps) if (area/total_area <= match_eps)
{ {
using namespace std; using namespace std;
...@@ -249,9 +249,9 @@ namespace dlib ...@@ -249,9 +249,9 @@ namespace dlib
sout << "image index "<< idx << endl; sout << "image index "<< idx << endl;
sout << "match_eps: "<< match_eps << endl; sout << "match_eps: "<< match_eps << endl;
sout << "best possible match: "<< area/total_area << endl; sout << "best possible match: "<< area/total_area << endl;
sout << "truth rect: "<< truth_object_detections[idx][i].rect << endl; sout << "truth rect: "<< truth_object_detections[idx][i].get_rect() << endl;
sout << "truth rect width/height: "<< truth_object_detections[idx][i].rect.width()/(double)truth_object_detections[idx][i].rect.height() << endl; sout << "truth rect width/height: "<< truth_object_detections[idx][i].get_rect().width()/(double)truth_object_detections[idx][i].get_rect().height() << endl;
sout << "truth rect area: "<< truth_object_detections[idx][i].rect.area() << endl; sout << "truth rect area: "<< truth_object_detections[idx][i].get_rect().area() << endl;
sout << "nearest detection template rect: "<< mapped_rects[i] << endl; sout << "nearest detection template rect: "<< mapped_rects[i] << endl;
sout << "nearest detection template rect width/height: "<< mapped_rects[i].width()/(double)mapped_rects[i].height() << endl; sout << "nearest detection template rect width/height: "<< mapped_rects[i].width()/(double)mapped_rects[i].height() << endl;
sout << "nearest detection template rect area: "<< mapped_rects[i].area() << endl; sout << "nearest detection template rect area: "<< mapped_rects[i].area() << endl;
...@@ -422,10 +422,10 @@ namespace dlib ...@@ -422,10 +422,10 @@ namespace dlib
for (unsigned long i = 0; i < boxes.size(); ++i) for (unsigned long i = 0; i < boxes.size(); ++i)
{ {
const unsigned long area = rect.intersect(boxes[i].rect).area(); const unsigned long area = rect.intersect(boxes[i].get_rect()).area();
if (area != 0) if (area != 0)
{ {
const double new_match = area / static_cast<double>((rect + boxes[i].rect).area()); const double new_match = area / static_cast<double>((rect + boxes[i].get_rect()).area());
if (new_match > match) if (new_match > match)
{ {
match = new_match; match = new_match;
......
...@@ -91,17 +91,20 @@ namespace dlib ...@@ -91,17 +91,20 @@ namespace dlib
- scanner.get_num_detection_templates() > 0 - scanner.get_num_detection_templates() > 0
- scanner.load(images[0]) must be a valid expression. - scanner.load(images[0]) must be a valid expression.
- for all valid i, j: - for all valid i, j:
- truth_object_detections[i][j].movable_rects.size() == scanner.get_num_movable_components_per_detection_template() - truth_object_detections[i][j].num_parts() == scanner.get_num_movable_components_per_detection_template()
- all_parts_in_rect(truth_object_detections[i][j]) == true - all_parts_in_rect(truth_object_detections[i][j]) == true
ensures ensures
- This object attempts to learn a mapping from the given images to the - This object attempts to learn a mapping from the given images to the
object locations given in truth_object_detections. In particular, it attempts to object locations given in truth_object_detections. In particular, it
learn to predict truth_object_detections[i] based on images[i]. attempts to learn to predict truth_object_detections[i] based on
Or in other words, this object can be used to learn a parameter vector, w, such that images[i]. Or in other words, this object can be used to learn a
an object_detector declared as: parameter vector, w, such that an object_detector declared as:
object_detector<image_scanner_type,overlap_tester_type> detector(scanner,overlap_tester,w) object_detector<image_scanner_type,overlap_tester_type> detector(scanner,overlap_tester,w)
results in a detector object which attempts to compute the following mapping: results in a detector object which attempts to compute the locations of
truth_object_detections[i].rect == detector(images[i]) all the objects in truth_object_detections. So if you called
detector(images[i]) you would hopefully get a list of rectangles back
that had truth_object_detections[i].size() elements and contained exactly
the rectangles indicated by truth_object_detections[i].
- #get_match_eps() == 0.5 - #get_match_eps() == 0.5
- This object will use num_threads threads during the optimization - This object will use num_threads threads during the optimization
procedure. You should set this parameter equal to the number of procedure. You should set this parameter equal to the number of
......
...@@ -275,48 +275,61 @@ namespace ...@@ -275,48 +275,61 @@ namespace
// Now make some squares and draw them onto our black images. All the // Now make some squares and draw them onto our black images. All the
// squares will be 70 pixels wide and tall. // squares will be 70 pixels wide and tall.
const int shrink = 0; const int shrink = 0;
std::vector<full_object_detection> temp; std::vector<full_object_detection> temp;
temp.push_back(full_object_detection(centered_rect(point(100,100), 70,71)));
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tr_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).bl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).br_corner());
fill_rect(images[0],temp.back().rect,255); // Paint the square white
temp.push_back(full_object_detection(centered_rect(point(200,300), 70,71)));
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tr_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).bl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).br_corner());
fill_rect(images[0],temp.back().rect,255); // Paint the square white
object_locations.push_back(temp); rectangle rect = centered_rect(point(100,100), 70,71);
std::vector<point> movable_parts;
movable_parts.push_back(shrink_rect(rect,shrink).tl_corner());
movable_parts.push_back(shrink_rect(rect,shrink).tr_corner());
movable_parts.push_back(shrink_rect(rect,shrink).bl_corner());
movable_parts.push_back(shrink_rect(rect,shrink).br_corner());
temp.push_back(full_object_detection(rect, movable_parts));
fill_rect(images[0],rect,255); // Paint the square white
rect = centered_rect(point(200,200), 70,71);
movable_parts.clear();
movable_parts.push_back(shrink_rect(rect,shrink).tl_corner());
movable_parts.push_back(shrink_rect(rect,shrink).tr_corner());
movable_parts.push_back(shrink_rect(rect,shrink).bl_corner());
movable_parts.push_back(shrink_rect(rect,shrink).br_corner());
temp.push_back(full_object_detection(rect, movable_parts));
fill_rect(images[0],rect,255); // Paint the square white
object_locations.push_back(temp);
// ------------------------------------
temp.clear(); temp.clear();
temp.push_back(full_object_detection(centered_rect(point(140,200), 70,71)));
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tr_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).bl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).br_corner());
fill_rect(images[1],temp.back().rect,255); // Paint the square white
temp.push_back(full_object_detection(centered_rect(point(303,200), 70,71)));
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tr_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).bl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).br_corner());
fill_rect(images[1],temp.back().rect,255); // Paint the square white
object_locations.push_back(temp); rect = centered_rect(point(140,200), 70,71);
movable_parts.clear();
movable_parts.push_back(shrink_rect(rect,shrink).tl_corner());
movable_parts.push_back(shrink_rect(rect,shrink).tr_corner());
movable_parts.push_back(shrink_rect(rect,shrink).bl_corner());
movable_parts.push_back(shrink_rect(rect,shrink).br_corner());
temp.push_back(full_object_detection(rect, movable_parts));
fill_rect(images[1],rect,255); // Paint the square white
rect = centered_rect(point(303,200), 70,71);
movable_parts.clear();
movable_parts.push_back(shrink_rect(rect,shrink).tl_corner());
movable_parts.push_back(shrink_rect(rect,shrink).tr_corner());
movable_parts.push_back(shrink_rect(rect,shrink).bl_corner());
movable_parts.push_back(shrink_rect(rect,shrink).br_corner());
temp.push_back(full_object_detection(rect, movable_parts));
fill_rect(images[1],rect,255); // Paint the square white
object_locations.push_back(temp);
// ------------------------------------
temp.clear(); temp.clear();
temp.push_back(full_object_detection(centered_rect(point(123,121), 70,71)));
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tl_corner()); rect = centered_rect(point(123,121), 70,71);
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).tr_corner()); movable_parts.clear();
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).bl_corner()); movable_parts.push_back(shrink_rect(rect,shrink).tl_corner());
temp.back().movable_parts.push_back(shrink_rect(temp.back().rect,shrink).br_corner()); movable_parts.push_back(shrink_rect(rect,shrink).tr_corner());
fill_rect(images[2],temp.back().rect,255); // Paint the square white movable_parts.push_back(shrink_rect(rect,shrink).bl_corner());
movable_parts.push_back(shrink_rect(rect,shrink).br_corner());
temp.push_back(full_object_detection(rect, movable_parts));
fill_rect(images[2],rect,255); // Paint the square white
object_locations.push_back(temp); object_locations.push_back(temp);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment