Commit 770781ac authored by Davis King's avatar Davis King

Added a label field to mmod_rect and updated code that uses this object to use

the label in the way you would expect.  In particular, loss_mmod_ is now a
multi-class loss and therefore capable of learning a detector that can output
detections with different labels.
parent b56f73ce
......@@ -230,6 +230,7 @@ namespace dlib
rects.push_back(mmod_rect(data.images[i].boxes[j].rect));
min_rect_size = std::min<double>(min_rect_size, rects.back().rect.area());
}
rects.back().label = data.images[i].boxes[j].label;
}
}
......
......@@ -220,10 +220,11 @@ namespace dlib
dlib/image_processing/generic_image.h.
ensures
- This function has essentially the same behavior as the above
load_image_dataset() routines, except here we out put to a vector of
load_image_dataset() routines, except here we output to a vector of
mmod_rects instead of rectangles. In this case, both ignore and non-ignore
rectangles go into object_locations since mmod_rect has an ignore boolean
field that records the ignored/non-ignored state of each rectangle.
field that records the ignored/non-ignored state of each rectangle. We also store
a each box's string label into the mmod_rect::label field as well.
!*/
// ----------------------------------------------------------------------------------------
......
......@@ -364,37 +364,42 @@ namespace dlib
{
public:
struct detector_window_size
struct detector_window_details
{
detector_window_size() = default;
detector_window_size(unsigned long w, unsigned long h) : width(w), height(h) {}
detector_window_details() = default;
detector_window_details(unsigned long w, unsigned long h) : width(w), height(h) {}
detector_window_details(unsigned long w, unsigned long h, const std::string& l) : width(w), height(h), label(l) {}
unsigned long width = 0;
unsigned long height = 0;
std::string label;
friend inline void serialize(const detector_window_size& item, std::ostream& out)
friend inline void serialize(const detector_window_details& item, std::ostream& out)
{
int version = 1;
int version = 2;
serialize(version, out);
serialize(item.width, out);
serialize(item.height, out);
serialize(item.label, out);
}
friend inline void deserialize(detector_window_size& item, std::istream& in)
friend inline void deserialize(detector_window_details& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::mmod_options::detector_window_size");
if (version != 1 && version != 2)
throw serialization_error("Unexpected version found while deserializing dlib::mmod_options::detector_window_details");
deserialize(item.width, in);
deserialize(item.height, in);
if (version == 2)
deserialize(item.label, in);
}
};
mmod_options() = default;
std::vector<detector_window_size> detector_windows;
std::vector<detector_window_details> detector_windows;
double loss_per_false_alarm = 1;
double loss_per_missed_target = 1;
double truth_match_iou_threshold = 0.5;
......@@ -412,7 +417,9 @@ namespace dlib
DLIB_CASSERT(0.5 < min_detector_window_overlap_iou && min_detector_window_overlap_iou < 1);
// Figure out what detector windows we will need.
for (auto ratio : find_covering_aspect_ratios(boxes, test_box_overlap(min_detector_window_overlap_iou)))
for (auto& label : get_labels(boxes))
{
for (auto ratio : find_covering_aspect_ratios(boxes, test_box_overlap(min_detector_window_overlap_iou), label))
{
double detector_width;
double detector_height;
......@@ -437,9 +444,10 @@ namespace dlib
}
}
detector_window_size p((unsigned long)std::round(detector_width), (unsigned long)std::round(detector_height));
detector_window_details p((unsigned long)std::round(detector_width), (unsigned long)std::round(detector_height), label);
detector_windows.push_back(p);
}
}
DLIB_CASSERT(detector_windows.size() != 0, "You can't call mmod_options's constructor with a set of boxes that is empty (or only contains ignored boxes).");
......@@ -542,9 +550,23 @@ namespace dlib
return exemplars;
}
static std::set<std::string> get_labels (
const std::vector<std::vector<mmod_rect>>& rects
)
{
std::set<std::string> labels;
for (auto& rr : rects)
{
for (auto& r : rr)
labels.insert(r.label);
}
return labels;
}
static std::vector<double> find_covering_aspect_ratios (
const std::vector<std::vector<mmod_rect>>& rects,
const test_box_overlap& overlaps
const test_box_overlap& overlaps,
const std::string& label
)
{
std::vector<rectangle> boxes;
......@@ -555,7 +577,7 @@ namespace dlib
{
for (auto&& b : bb)
{
if (!b.ignore)
if (!b.ignore && b.label == label)
boxes.push_back(move_rect(set_rect_area(b.rect,400*400), point(0,0)));
}
}
......@@ -593,7 +615,7 @@ namespace dlib
unsigned long height;
deserialize(width, in);
deserialize(height, in);
item.detector_windows = {mmod_options::detector_window_size(width, height)};
item.detector_windows = {mmod_options::detector_window_details(width, height)};
}
else
{
......@@ -612,21 +634,23 @@ namespace dlib
{
struct intermediate_detection
{
intermediate_detection() : detection_confidence(0), tensor_offset(0) {}
intermediate_detection() = default;
intermediate_detection(
rectangle rect_
) : rect(rect_), detection_confidence(0), tensor_offset(0) {}
) : rect(rect_) {}
intermediate_detection(
rectangle rect_,
double detection_confidence_,
size_t tensor_offset_
) : rect(rect_), detection_confidence(detection_confidence_), tensor_offset(tensor_offset_) {}
size_t tensor_offset_,
long channel
) : rect(rect_), detection_confidence(detection_confidence_), tensor_offset(tensor_offset_), tensor_channel(channel) {}
rectangle rect;
double detection_confidence;
size_t tensor_offset;
double detection_confidence = 0;
size_t tensor_offset = 0;
long tensor_channel = 0;
bool operator<(const intermediate_detection& item) const { return detection_confidence < item.detection_confidence; }
};
......@@ -672,7 +696,9 @@ namespace dlib
if (overlaps_any_box_nms(final_dets, dets_accum[i].rect))
continue;
final_dets.push_back(mmod_rect(dets_accum[i].rect, dets_accum[i].detection_confidence));
final_dets.push_back(mmod_rect(dets_accum[i].rect,
dets_accum[i].detection_confidence,
options.detector_windows[dets_accum[i].tensor_channel].label));
}
*iter++ = std::move(final_dets);
......@@ -728,7 +754,7 @@ namespace dlib
{
size_t k;
point p;
if(image_rect_to_feat_coord(p, input_tensor, x, sub, k))
if(image_rect_to_feat_coord(p, input_tensor, x, x.label, sub, k))
{
// Ignore boxes that can't be detected by the CNN.
loss -= 1;
......@@ -761,7 +787,9 @@ namespace dlib
if (overlaps_any_box_nms(final_dets, dets[i].rect))
continue;
const std::pair<double,unsigned int> hittruth = find_best_match(*truth, dets[i].rect);
const auto& det_label = options.detector_windows[dets[i].tensor_channel].label;
const std::pair<double,unsigned int> hittruth = find_best_match(*truth, dets[i].rect, det_label);
final_dets.push_back(dets[i].rect);
......@@ -807,7 +835,8 @@ namespace dlib
std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << (*truth)[i].rect;
std::cout << " that is suppressed by non-max-suppression ";
std::cout << "because it is overlapped by another truth rectangle located at " << best_matching_truth_box
<< " (IoU:"<< box_intersection_over_union(best_matching_truth_box,(*truth)[i]) <<")." << std::endl;
<< " (IoU:"<< box_intersection_over_union(best_matching_truth_box,(*truth)[i]) <<", Percent covered:"
<< box_percent_covered(best_matching_truth_box,(*truth)[i]) << ")." << std::endl;
}
}
}
......@@ -825,7 +854,9 @@ namespace dlib
if (overlaps_any_box_nms(final_dets, dets[i].rect))
continue;
const std::pair<double,unsigned int> hittruth = find_best_match(*truth, dets[i].rect);
const auto& det_label = options.detector_windows[dets[i].tensor_channel].label;
const std::pair<double,unsigned int> hittruth = find_best_match(*truth, dets[i].rect, det_label);
const double truth_match = hittruth.first;
if (truth_match > options.truth_match_iou_threshold)
......@@ -949,7 +980,7 @@ namespace dlib
drectangle rect = centered_drect(p, options.detector_windows[k].width, options.detector_windows[k].height);
rect = input_layer(net).tensor_space_to_image_space(input_tensor,rect);
dets_accum.push_back(intermediate_detection(rect, score, (k*output_tensor.nr() + r)*output_tensor.nc() + c));
dets_accum.push_back(intermediate_detection(rect, score, (k*output_tensor.nr() + r)*output_tensor.nc() + c, k));
}
}
}
......@@ -958,7 +989,8 @@ namespace dlib
}
size_t find_best_detection_window (
rectangle rect
rectangle rect,
const std::string& label
) const
{
rect = move_rect(set_rect_area(rect, 400*400), point(0,0));
......@@ -969,6 +1001,8 @@ namespace dlib
double best_ratio_diff = -std::numeric_limits<double>::infinity();
for (size_t i = 0; i < options.detector_windows.size(); ++i)
{
if (options.detector_windows[i].label != label)
continue;
rectangle det_window = centered_rect(point(0,0), options.detector_windows[i].width, options.detector_windows[i].height);
det_window = move_rect(set_rect_area(det_window, 400*400), point(0,0));
......@@ -987,6 +1021,7 @@ namespace dlib
point& tensor_p,
const tensor& input_tensor,
const rectangle& rect,
const std::string& label,
const net_type& net,
size_t& det_idx
) const
......@@ -1000,7 +1035,7 @@ namespace dlib
throw impossible_labeling_error(sout.str());
}
det_idx = find_best_detection_window(rect);
det_idx = find_best_detection_window(rect,label);
// Compute the scale we need to be at to get from rect to our detection window.
// Note that we compute the scale as the max of two numbers. It doesn't
......@@ -1065,10 +1100,26 @@ namespace dlib
std::pair<double,unsigned int> find_best_match(
const std::vector<mmod_rect>& boxes,
const rectangle& rect
const rectangle& rect,
const std::string& label
) const
{
return find_best_match(boxes, rect, boxes.size());
double match = 0;
unsigned int best_idx = 0;
for (unsigned long i = 0; i < boxes.size(); ++i)
{
if (boxes[i].ignore || boxes[i].label != label)
continue;
const double new_match = box_intersection_over_union(rect, boxes[i]);
if (new_match > match)
{
match = new_match;
best_idx = i;
}
}
return std::make_pair(match,best_idx);
}
std::pair<double,unsigned int> find_best_match(
......
......@@ -374,25 +374,29 @@ namespace dlib
public:
struct detector_window_size
struct detector_window_details
{
detector_window_size() = default;
detector_window_size(unsigned long w, unsigned long h) : width(w), height(h) {}
detector_window_details() = default;
detector_window_details(unsigned long w, unsigned long h) : width(w), height(h) {}
detector_window_details(unsigned long w, unsigned long h, const std::string& l) : width(w), height(h), label(l) {}
unsigned long width = 0;
unsigned long height = 0;
std::string label;
friend inline void serialize(const detector_window_size& item, std::ostream& out);
friend inline void deserialize(detector_window_size& item, std::istream& in);
friend inline void serialize(const detector_window_details& item, std::ostream& out);
friend inline void deserialize(detector_window_details& item, std::istream& in);
};
mmod_options() = default;
// This kind of object detector is a sliding window detector. The detector_windows
// field determines how many sliding windows we will use and what the shape of each
// window is. Since you will usually use the MMOD loss with an image pyramid, the
// detector sizes also determine the size of the smallest object you can detect.
std::vector<detector_window_size> detector_windows;
// window is. It also determines the output label applied to each detection
// identified by each window. Since you will usually use the MMOD loss with an
// image pyramid, the detector sizes also determine the size of the smallest object
// you can detect.
std::vector<detector_window_details> detector_windows;
// These parameters control how we penalize different kinds of mistakes. See
// Max-Margin Object Detection by Davis E. King (http://arxiv.org/abs/1502.00046)
......@@ -439,6 +443,10 @@ namespace dlib
each box in boxes could potentially be detected by one of the
detector windows. This essentially comes down to picking detector
windows with aspect ratios similar to the aspect ratios in boxes.
Note that we also make sure that each box can be detected by a window
with the same label. For example, if all the boxes had the same
aspect ratio but there were 4 different labels used in boxes then
there would be 4 resulting detector windows, one for each label.
- The longest edge of each detector window is target_size pixels in
length, unless the window's shortest side would be less than
min_target_size pixels in length. In this case the shortest side
......
......@@ -6,9 +6,22 @@
#include "../svm/cross_validate_object_detection_trainer_abstract.h"
#include "../svm/cross_validate_object_detection_trainer.h"
#include "layers.h"
#include <set>
namespace dlib
{
namespace impl
{
inline std::set<std::string> get_labels (
const std::vector<mmod_rect>& rects
)
{
std::set<std::string> labels;
for (auto& rr : rects)
labels.insert(rr.label);
return labels;
}
}
template <
typename SUBNET,
......@@ -49,6 +62,8 @@ namespace dlib
detector.loss_details().to_label(temp, detector.subnet(), &hits, adjust_threshold);
for (auto& label : impl::get_labels(truth_dets[i]))
{
std::vector<full_object_detection> truth_boxes;
std::vector<rectangle> ignore;
std::vector<std::pair<double,rectangle>> boxes;
......@@ -56,15 +71,23 @@ namespace dlib
for (auto&& b : truth_dets[i])
{
if (b.ignore)
{
ignore.push_back(b);
else
}
else if (b.label == label)
{
truth_boxes.push_back(full_object_detection(b.rect));
++total_true_targets;
}
}
for (auto&& b : hits)
{
if (b.label == label)
boxes.push_back(std::make_pair(b.detection_confidence, b.rect));
}
correct_hits += impl::number_of_truth_hits(truth_boxes, ignore, boxes, overlap_tester, all_dets, missing_detections, overlaps_ignore_tester);
total_true_targets += truth_boxes.size();
}
}
std::sort(all_dets.rbegin(), all_dets.rend());
......
......@@ -134,17 +134,20 @@ namespace dlib
mmod_rect() = default;
mmod_rect(const rectangle& r) : rect(r) {}
mmod_rect(const rectangle& r, double score) : rect(r),detection_confidence(score) {}
mmod_rect(const rectangle& r, double score, const std::string& label) : rect(r),detection_confidence(score), label(label) {}
rectangle rect;
double detection_confidence = 0;
bool ignore = false;
std::string label;
operator rectangle() const { return rect; }
bool operator == (const mmod_rect& rhs) const
{
return rect == rhs.rect
&& detection_confidence == rhs.detection_confidence
&& ignore == rhs.ignore;
&& ignore == rhs.ignore
&& label == rhs.label;
}
};
......@@ -157,22 +160,27 @@ namespace dlib
inline void serialize(const mmod_rect& item, std::ostream& out)
{
int version = 1;
int version = 2;
serialize(version, out);
serialize(item.rect, out);
serialize(item.detection_confidence, out);
serialize(item.ignore, out);
serialize(item.label, out);
}
inline void deserialize(mmod_rect& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
if (version != 1 && version != 2)
throw serialization_error("Unexpected version found while deserializing dlib::mmod_rect");
deserialize(item.rect, in);
deserialize(item.detection_confidence, in);
deserialize(item.ignore, in);
if (version == 2)
deserialize(item.label, in);
else
item.label = "";
}
// ----------------------------------------------------------------------------------------
......
......@@ -159,12 +159,21 @@ namespace dlib
mmod_rect() = default;
mmod_rect(const rectangle& r) : rect(r) {}
mmod_rect(const rectangle& r, double score) : rect(r),detection_confidence(score) {}
mmod_rect(const rectangle& r, double score, const std::string& label) : rect(r),detection_confidence(score),label(label) {}
rectangle rect;
double detection_confidence = 0;
bool ignore = false;
std::string label;
operator rectangle() const { return rect; }
bool operator == (const mmod_rect& rhs) const;
/*!
ensures
- returns true if and only if all the elements of this object compare equal
to the corresponding elements of rhs.
!*/
};
mmod_rect ignored_mmod_rect(
......@@ -176,6 +185,7 @@ namespace dlib
- R.rect == r
- R.ignore == true
- R.detection_confidence == 0
- R.label == ""
!*/
void serialize(const mmod_rect& item, std::ostream& out);
......
......@@ -155,10 +155,11 @@ namespace dlib
object_detector's except it runs on CNNs that use loss_mmod.
- Tests the given detector against the supplied object detection problem and
returns the precision, recall, and average precision. Note that the task is
to predict, for each images[i], the set of object locations given by
truth_dets[i]. Additionally, any detections on image[i] that match a box in
truth_dets[i] that are marked ignore are ignored. That is, detections
matching an ignore box do not count as a false alarm and similarly if any
to predict, for each images[i], the set of object locations, and their
corresponding labels, given by truth_dets[i]. Additionally, any detections
on image[i] that match a box in truth_dets[i] that are marked ignore are
ignored. That is, detections matching an ignore box, regardless of the
ignore box's label, do not count as a false alarm and similarly if any
ignored box in truth_dets goes undetected it does not count as a missed
detection. To test if a box overlaps an ignore box, we use overlaps_ignore_tester.
- In particular, returns a matrix M such that:
......@@ -178,8 +179,8 @@ namespace dlib
ordering them in descending order of their detection scores. Then we use
the average_precision() routine to score the ranked listing and store the
output into M(2).
- This function considers a detector output D to match a rectangle T if and
only if overlap_tester(T,D) returns true.
- This function considers a detector output D to match a truth rectangle T if
and only if overlap_tester(T,D) returns true and the labels are identical strings.
- Note that you can use the adjust_threshold argument to raise or lower the
detection threshold. This value is passed into the identically named
argument to the detector object and therefore influences the number of
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment