Commit 770781ac authored by Davis King's avatar Davis King

Added a label field to mmod_rect and updated code that uses this object to use

the label in the way you would expect.  In particular, loss_mmod_ is now a
multi-class loss and therefore capable of learning a detector that can output
detections with different labels.
parent b56f73ce
......@@ -230,6 +230,7 @@ namespace dlib
rects.push_back(mmod_rect(data.images[i].boxes[j].rect));
min_rect_size = std::min<double>(min_rect_size, rects.back().rect.area());
}
rects.back().label = data.images[i].boxes[j].label;
}
}
......
......@@ -220,10 +220,11 @@ namespace dlib
dlib/image_processing/generic_image.h.
ensures
- This function has essentially the same behavior as the above
load_image_dataset() routines, except here we out put to a vector of
load_image_dataset() routines, except here we output to a vector of
mmod_rects instead of rectangles. In this case, both ignore and non-ignore
rectangles go into object_locations since mmod_rect has an ignore boolean
field that records the ignored/non-ignored state of each rectangle.
field that records the ignored/non-ignored state of each rectangle. We also store
a each box's string label into the mmod_rect::label field as well.
!*/
// ----------------------------------------------------------------------------------------
......
This diff is collapsed.
......@@ -374,25 +374,29 @@ namespace dlib
public:
struct detector_window_size
struct detector_window_details
{
detector_window_size() = default;
detector_window_size(unsigned long w, unsigned long h) : width(w), height(h) {}
detector_window_details() = default;
detector_window_details(unsigned long w, unsigned long h) : width(w), height(h) {}
detector_window_details(unsigned long w, unsigned long h, const std::string& l) : width(w), height(h), label(l) {}
unsigned long width = 0;
unsigned long height = 0;
std::string label;
friend inline void serialize(const detector_window_size& item, std::ostream& out);
friend inline void deserialize(detector_window_size& item, std::istream& in);
friend inline void serialize(const detector_window_details& item, std::ostream& out);
friend inline void deserialize(detector_window_details& item, std::istream& in);
};
mmod_options() = default;
// This kind of object detector is a sliding window detector. The detector_windows
// field determines how many sliding windows we will use and what the shape of each
// window is. Since you will usually use the MMOD loss with an image pyramid, the
// detector sizes also determine the size of the smallest object you can detect.
std::vector<detector_window_size> detector_windows;
// window is. It also determines the output label applied to each detection
// identified by each window. Since you will usually use the MMOD loss with an
// image pyramid, the detector sizes also determine the size of the smallest object
// you can detect.
std::vector<detector_window_details> detector_windows;
// These parameters control how we penalize different kinds of mistakes. See
// Max-Margin Object Detection by Davis E. King (http://arxiv.org/abs/1502.00046)
......@@ -439,6 +443,10 @@ namespace dlib
each box in boxes could potentially be detected by one of the
detector windows. This essentially comes down to picking detector
windows with aspect ratios similar to the aspect ratios in boxes.
Note that we also make sure that each box can be detected by a window
with the same label. For example, if all the boxes had the same
aspect ratio but there were 4 different labels used in boxes then
there would be 4 resulting detector windows, one for each label.
- The longest edge of each detector window is target_size pixels in
length, unless the window's shortest side would be less than
min_target_size pixels in length. In this case the shortest side
......
......@@ -6,9 +6,22 @@
#include "../svm/cross_validate_object_detection_trainer_abstract.h"
#include "../svm/cross_validate_object_detection_trainer.h"
#include "layers.h"
#include <set>
namespace dlib
{
namespace impl
{
inline std::set<std::string> get_labels (
const std::vector<mmod_rect>& rects
)
{
std::set<std::string> labels;
for (auto& rr : rects)
labels.insert(rr.label);
return labels;
}
}
template <
typename SUBNET,
......@@ -49,22 +62,32 @@ namespace dlib
detector.loss_details().to_label(temp, detector.subnet(), &hits, adjust_threshold);
std::vector<full_object_detection> truth_boxes;
std::vector<rectangle> ignore;
std::vector<std::pair<double,rectangle>> boxes;
// copy hits and truth_dets into the above three objects
for (auto&& b : truth_dets[i])
for (auto& label : impl::get_labels(truth_dets[i]))
{
if (b.ignore)
ignore.push_back(b);
else
truth_boxes.push_back(full_object_detection(b.rect));
std::vector<full_object_detection> truth_boxes;
std::vector<rectangle> ignore;
std::vector<std::pair<double,rectangle>> boxes;
// copy hits and truth_dets into the above three objects
for (auto&& b : truth_dets[i])
{
if (b.ignore)
{
ignore.push_back(b);
}
else if (b.label == label)
{
truth_boxes.push_back(full_object_detection(b.rect));
++total_true_targets;
}
}
for (auto&& b : hits)
{
if (b.label == label)
boxes.push_back(std::make_pair(b.detection_confidence, b.rect));
}
correct_hits += impl::number_of_truth_hits(truth_boxes, ignore, boxes, overlap_tester, all_dets, missing_detections, overlaps_ignore_tester);
}
for (auto&& b : hits)
boxes.push_back(std::make_pair(b.detection_confidence, b.rect));
correct_hits += impl::number_of_truth_hits(truth_boxes, ignore, boxes, overlap_tester, all_dets, missing_detections, overlaps_ignore_tester);
total_true_targets += truth_boxes.size();
}
std::sort(all_dets.rbegin(), all_dets.rend());
......
......@@ -134,17 +134,20 @@ namespace dlib
mmod_rect() = default;
mmod_rect(const rectangle& r) : rect(r) {}
mmod_rect(const rectangle& r, double score) : rect(r),detection_confidence(score) {}
mmod_rect(const rectangle& r, double score, const std::string& label) : rect(r),detection_confidence(score), label(label) {}
rectangle rect;
double detection_confidence = 0;
bool ignore = false;
std::string label;
operator rectangle() const { return rect; }
bool operator == (const mmod_rect& rhs) const
{
return rect == rhs.rect
&& detection_confidence == rhs.detection_confidence
&& ignore == rhs.ignore;
&& ignore == rhs.ignore
&& label == rhs.label;
}
};
......@@ -157,22 +160,27 @@ namespace dlib
inline void serialize(const mmod_rect& item, std::ostream& out)
{
int version = 1;
int version = 2;
serialize(version, out);
serialize(item.rect, out);
serialize(item.detection_confidence, out);
serialize(item.ignore, out);
serialize(item.label, out);
}
inline void deserialize(mmod_rect& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
if (version != 1 && version != 2)
throw serialization_error("Unexpected version found while deserializing dlib::mmod_rect");
deserialize(item.rect, in);
deserialize(item.detection_confidence, in);
deserialize(item.ignore, in);
if (version == 2)
deserialize(item.label, in);
else
item.label = "";
}
// ----------------------------------------------------------------------------------------
......
......@@ -159,12 +159,21 @@ namespace dlib
mmod_rect() = default;
mmod_rect(const rectangle& r) : rect(r) {}
mmod_rect(const rectangle& r, double score) : rect(r),detection_confidence(score) {}
mmod_rect(const rectangle& r, double score, const std::string& label) : rect(r),detection_confidence(score),label(label) {}
rectangle rect;
double detection_confidence = 0;
bool ignore = false;
std::string label;
operator rectangle() const { return rect; }
bool operator == (const mmod_rect& rhs) const;
/*!
ensures
- returns true if and only if all the elements of this object compare equal
to the corresponding elements of rhs.
!*/
};
mmod_rect ignored_mmod_rect(
......@@ -176,6 +185,7 @@ namespace dlib
- R.rect == r
- R.ignore == true
- R.detection_confidence == 0
- R.label == ""
!*/
void serialize(const mmod_rect& item, std::ostream& out);
......
......@@ -155,12 +155,13 @@ namespace dlib
object_detector's except it runs on CNNs that use loss_mmod.
- Tests the given detector against the supplied object detection problem and
returns the precision, recall, and average precision. Note that the task is
to predict, for each images[i], the set of object locations given by
truth_dets[i]. Additionally, any detections on image[i] that match a box in
truth_dets[i] that are marked ignore are ignored. That is, detections
matching an ignore box do not count as a false alarm and similarly if any
to predict, for each images[i], the set of object locations, and their
corresponding labels, given by truth_dets[i]. Additionally, any detections
on image[i] that match a box in truth_dets[i] that are marked ignore are
ignored. That is, detections matching an ignore box, regardless of the
ignore box's label, do not count as a false alarm and similarly if any
ignored box in truth_dets goes undetected it does not count as a missed
detection. To test if a box overlaps an ignore box, we use overlaps_ignore_tester.
detection. To test if a box overlaps an ignore box, we use overlaps_ignore_tester.
- In particular, returns a matrix M such that:
- M(0) == the precision of the detector object. This is a number
in the range [0,1] which measures the fraction of detector outputs
......@@ -178,8 +179,8 @@ namespace dlib
ordering them in descending order of their detection scores. Then we use
the average_precision() routine to score the ranked listing and store the
output into M(2).
- This function considers a detector output D to match a rectangle T if and
only if overlap_tester(T,D) returns true.
- This function considers a detector output D to match a truth rectangle T if
and only if overlap_tester(T,D) returns true and the labels are identical strings.
- Note that you can use the adjust_threshold argument to raise or lower the
detection threshold. This value is passed into the identically named
argument to the detector object and therefore influences the number of
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment