Commit f484cec7 authored by Juha Reunanen, committed by Davis E. King

Add support for non-scale-invariant MMOD (#809)

* Add capability to train scale-variant MMOD models

* Review fixes: change bool scale_invariant to strongly typed enum, etc

* Add serialization and deserialization of assumed_input_layer_type

* Fix code formatting

* Rename things as per review feedback

* Review fix: move enum use_image_pyramid outside mmod_options

* Continue execution with net, if deserialization of shape predictor fails

* Revert "Continue execution with net, if deserialization of shape predictor fails"

This reverts commit 8ea4482c043b5b98b97ed5b78bfc6916a1e2a453.
parent 36b04362
......@@ -361,6 +361,12 @@ namespace dlib
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// Strongly typed flag saying whether an MMOD detector is assumed to be used
// together with an image pyramid (yes) or without one (no).
// The underlying type is fixed to uint8_t and the enumerator values are spelled
// out because mmod_options serialization stores this flag as a raw uint8_t.
enum class use_image_pyramid : uint8_t
{
    no = 0,
    yes = 1
};
struct mmod_options
{
public:
......@@ -407,6 +413,8 @@ namespace dlib
test_box_overlap overlaps_nms = test_box_overlap(0.4);
test_box_overlap overlaps_ignore;
use_image_pyramid assume_image_pyramid = use_image_pyramid::yes;
mmod_options (
const std::vector<std::vector<mmod_rect>>& boxes,
const unsigned long target_size, // We want the length of the longest dimension of the detector window to be this.
......@@ -452,8 +460,37 @@ namespace dlib
DLIB_CASSERT(detector_windows.size() != 0, "You can't call mmod_options's constructor with a set of boxes that is empty (or only contains ignored boxes).");
set_overlap_nms(boxes);
}
// Constructs options for a detector that is not meant to be scale-invariant.
//
// assume_image_pyramid must be use_image_pyramid::no; the argument exists to
// distinguish this constructor from the pyramid-based one.
// boxes holds the training boxes, one inner vector per image.
// min_detector_window_overlap_iou must lie strictly between 0.5 and 1.
mmod_options(
    use_image_pyramid assume_image_pyramid,
    const std::vector<std::vector<mmod_rect>>& boxes,
    const double min_detector_window_overlap_iou = 0.75
)
    : assume_image_pyramid(assume_image_pyramid)
{
    DLIB_CASSERT(assume_image_pyramid == use_image_pyramid::no);
    DLIB_CASSERT(0.5 < min_detector_window_overlap_iou && min_detector_window_overlap_iou < 1);

    // The same overlap criterion is used for every label.
    const test_box_overlap window_overlap_tester(min_detector_window_overlap_iou);

    // Work out, label by label, which detector windows are required. Window
    // sizes are taken as-is from the covering rectangles.
    for (const auto& current_label : get_labels(boxes))
    {
        const auto covering_rects = find_covering_rectangles(boxes, window_overlap_tester, current_label);
        for (const auto& window : covering_rects)
            detector_windows.push_back(detector_window_details(window.width(), window.height(), current_label));
    }

    DLIB_CASSERT(detector_windows.size() != 0, "You can't call mmod_options's constructor with a set of boxes that is empty (or only contains ignored boxes).");

    set_overlap_nms(boxes);
}
private:
void set_overlap_nms(const std::vector<std::vector<mmod_rect>>& boxes)
{
// Convert from mmod_rect to rectangle so we can call
// find_tight_overlap_tester().
std::vector<std::vector<rectangle>> temp;
......@@ -480,9 +517,6 @@ namespace dlib
overlaps_nms = test_box_overlap(iou_thresh, percent_covered_thresh);
}
private:
static double advance_toward_1 (
double val
)
......@@ -589,11 +623,31 @@ namespace dlib
return ratios;
}
static std::vector<dlib::rectangle> find_covering_rectangles (
const std::vector<std::vector<mmod_rect>>& rects,
const test_box_overlap& overlaps,
const std::string& label
)
{
std::vector<rectangle> boxes;
// Make sure all the boxes have the same position, so that the we only check for
// width and height.
for (auto& bb : rects)
{
for (auto&& b : bb)
{
if (!b.ignore && b.label == label)
boxes.push_back(rectangle(b.rect.width(), b.rect.height()));
}
}
return find_rectangles_overlapping_all_others(boxes, overlaps);
}
};
inline void serialize(const mmod_options& item, std::ostream& out)
{
int version = 2;
int version = 3;
serialize(version, out);
serialize(item.detector_windows, out);
......@@ -602,13 +656,14 @@ namespace dlib
serialize(item.truth_match_iou_threshold, out);
serialize(item.overlaps_nms, out);
serialize(item.overlaps_ignore, out);
serialize(static_cast<uint8_t>(item.assume_image_pyramid), out);
}
inline void deserialize(mmod_options& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 2 && version != 1)
if (version != 3 && version != 2 && version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::mmod_options");
if (version == 1)
{
......@@ -627,6 +682,12 @@ namespace dlib
deserialize(item.truth_match_iou_threshold, in);
deserialize(item.overlaps_nms, in);
deserialize(item.overlaps_ignore, in);
if (version >= 3)
{
uint8_t assume_image_pyramid = 0;
deserialize(assume_image_pyramid, in);
item.assume_image_pyramid = static_cast<use_image_pyramid>(assume_image_pyramid);
}
}
// ----------------------------------------------------------------------------------------
......@@ -764,7 +825,7 @@ namespace dlib
{
size_t k;
point p;
if(image_rect_to_feat_coord(p, input_tensor, x, x.label, sub, k))
if(image_rect_to_feat_coord(p, input_tensor, x, x.label, sub, k, options.assume_image_pyramid))
{
// Ignore boxes that can't be detected by the CNN.
loss -= options.loss_per_missed_target;
......@@ -1000,21 +1061,39 @@ namespace dlib
size_t find_best_detection_window (
rectangle rect,
const std::string& label
const std::string& label,
use_image_pyramid assume_image_pyramid
) const
{
rect = move_rect(set_rect_area(rect, 400*400), point(0,0));
if (assume_image_pyramid == use_image_pyramid::yes)
{
rect = move_rect(set_rect_area(rect, 400*400), point(0,0));
}
else
{
rect = rectangle(rect.width(), rect.height());
}
// Figure out which detection window in options.detector_windows has the most
// similar aspect ratio to rect.
// Figure out which detection window in options.detector_windows is most similar to rect
// (in terms of aspect ratio, if assume_image_pyramid == use_image_pyramid::yes).
size_t best_i = 0;
double best_ratio_diff = -std::numeric_limits<double>::infinity();
for (size_t i = 0; i < options.detector_windows.size(); ++i)
{
if (options.detector_windows[i].label != label)
continue;
rectangle det_window = centered_rect(point(0,0), options.detector_windows[i].width, options.detector_windows[i].height);
det_window = move_rect(set_rect_area(det_window, 400*400), point(0,0));
rectangle det_window;
if (options.assume_image_pyramid == use_image_pyramid::yes)
{
det_window = centered_rect(point(0,0), options.detector_windows[i].width, options.detector_windows[i].height);
det_window = move_rect(set_rect_area(det_window, 400*400), point(0,0));
}
else
{
det_window = rectangle(options.detector_windows[i].width, options.detector_windows[i].height);
}
double iou = box_intersection_over_union(rect, det_window);
if (iou > best_ratio_diff)
......@@ -1033,7 +1112,8 @@ namespace dlib
const rectangle& rect,
const std::string& label,
const net_type& net,
size_t& det_idx
size_t& det_idx,
use_image_pyramid assume_image_pyramid
) const
{
using namespace std;
......@@ -1045,14 +1125,24 @@ namespace dlib
throw impossible_labeling_error(sout.str());
}
det_idx = find_best_detection_window(rect,label);
det_idx = find_best_detection_window(rect,label,assume_image_pyramid);
double scale = 1.0;
if (options.assume_image_pyramid == use_image_pyramid::yes)
{
// Compute the scale we need to be at to get from rect to our detection window.
// Note that we compute the scale as the max of two numbers. It doesn't
// actually matter which one we pick, because if they are very different then
// it means the box can't be matched by the sliding window. But picking the
// max causes the right error message to be selected in the logic below.
scale = std::max(options.detector_windows[det_idx].width/(double)rect.width(), options.detector_windows[det_idx].height/(double)rect.height());
}
else
{
// We don't want invariance to scale.
scale = 1.0;
}
// Compute the scale we need to be at to get from rect to our detection window.
// Note that we compute the scale as the max of two numbers. It doesn't
// actually matter which one we pick, because if they are very different then
// it means the box can't be matched by the sliding window. But picking the
// max causes the right error message to be selected in the logic below.
const double scale = std::max(options.detector_windows[det_idx].width/(double)rect.width(), options.detector_windows[det_idx].height/(double)rect.height());
const rectangle mapped_rect = input_layer(net).image_space_to_tensor_space(input_tensor, std::min(1.0,scale), rect);
// compute the detection window that we would use at this position.
......
......@@ -365,6 +365,12 @@ namespace dlib
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// Strongly typed flag saying whether an MMOD detector is assumed to be used
// together with an image pyramid (yes) or without one (no).
enum class use_image_pyramid : uint8_t
{
    no = 0,
    yes = 1
};
struct mmod_options
{
/*!
......@@ -420,6 +426,10 @@ namespace dlib
// don't care if the detector gets them or not.
test_box_overlap overlaps_ignore;
// Usually the detector would be scale-invariant, and used with an image pyramid.
// However, sometimes scale-invariance may not be desired.
use_image_pyramid assume_image_pyramid = use_image_pyramid::yes;
mmod_options (
const std::vector<std::vector<mmod_rect>>& boxes,
const unsigned long target_size,
......@@ -431,6 +441,9 @@ namespace dlib
- 0 < min_target_size <= target_size
- 0.5 < min_detector_window_overlap_iou < 1
ensures
- assume_image_pyramid == use_image_pyramid::yes
- This function should be used when scale-invariance is desired, and
input_rgb_image_pyramid is therefore used as the input layer.
- This function tries to automatically set the MMOD options to reasonable
values, assuming you have a training dataset of boxes.size() images, where
the ith image contains objects boxes[i] you want to detect.
......@@ -461,6 +474,33 @@ namespace dlib
- This function will also set the overlaps_nms tester to the most
restrictive tester that doesn't reject anything in boxes.
!*/
mmod_options (
    use_image_pyramid assume_image_pyramid,
    const std::vector<std::vector<mmod_rect>>& boxes,
    const double min_detector_window_overlap_iou = 0.75
);
/*!
    requires
        - assume_image_pyramid == use_image_pyramid::no
        - 0.5 < min_detector_window_overlap_iou < 1
        - boxes contains at least one box that is not marked as ignored
    ensures
        - this->assume_image_pyramid == use_image_pyramid::no
        - This function should be used when scale-invariance is not desired, and
          there is no intention to apply an image pyramid.
        - This function tries to automatically set the MMOD options to reasonable
          values, assuming you have a training dataset of boxes.size() images, where
          the ith image contains objects boxes[i] you want to detect.
        - The most important thing this function does is decide what detector
          windows should be used.  This is done by finding a set of detector
          windows that are sized such that:
            - When slid over an image, each box in boxes will have an
              intersection-over-union with one of the detector windows of at least
              min_detector_window_overlap_iou.  That is, we will make sure that
              each box in boxes could potentially be detected by one of the
              detector windows.
        - This function will also set the overlaps_nms tester to the most
          restrictive tester that doesn't reject anything in boxes.
!*/
};
void serialize(const mmod_options& item, std::ostream& out);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment