Commit f5574434 authored by Davis King's avatar Davis King

Upgraded loss_mmod_ to support objects of varying aspect ratio. This changes
the API for the mmod_options struct slightly.
parent dcfff1c4
This diff is collapsed.
...@@ -369,14 +369,25 @@ namespace dlib ...@@ -369,14 +369,25 @@ namespace dlib
public: public:
// Width and height of a single sliding detector window.
struct detector_window_size
{
    // Both dimensions start at 0 until a size is supplied.
    unsigned long width{0};
    unsigned long height{0};

    detector_window_size() = default;

    // Builds a window of the given width and height.
    detector_window_size(unsigned long window_width, unsigned long window_height)
        : width{window_width}, height{window_height}
    {
    }

    // Stream serialization hooks; only declared here and granted friend access.
    friend inline void serialize(const detector_window_size& item, std::ostream& out);
    friend inline void deserialize(detector_window_size& item, std::istream& in);
};
mmod_options() = default; mmod_options() = default;
// This kind of object detector is a sliding window detector. These two parameters // This kind of object detector is a sliding window detector. The detector_windows
// determine the size of the sliding window. Since you will usually use the MMOD // field determines how many sliding windows we will use and what the shape of each
// loss with an image pyramid the detector size determines the size of the smallest // window is. Since you will usually use the MMOD loss with an image pyramid, the
// object you can detect. // detector sizes also determine the size of the smallest object you can detect.
unsigned long detector_width = 80; std::vector<detector_window_size> detector_windows;
unsigned long detector_height = 80;
// These parameters control how we penalize different kinds of mistakes. See // These parameters control how we penalize different kinds of mistakes. See
// Max-Margin Object Detection by Davis E. King (http://arxiv.org/abs/1502.00046) // Max-Margin Object Detection by Davis E. King (http://arxiv.org/abs/1502.00046)
...@@ -402,18 +413,38 @@ namespace dlib ...@@ -402,18 +413,38 @@ namespace dlib
mmod_options ( mmod_options (
const std::vector<std::vector<mmod_rect>>& boxes, const std::vector<std::vector<mmod_rect>>& boxes,
const unsigned long target_size = 6400 const unsigned long target_size,
const unsigned long min_target_size,
const double min_detector_window_overlap_iou = 0.75
); );
/*! /*!
requires
- 0 < min_target_size <= target_size
- 0.5 < min_detector_window_overlap_iou < 1
ensures ensures
- This function tries to automatically set the MMOD options so reasonable - This function tries to automatically set the MMOD options to reasonable
values assuming you have a training dataset of boxes.size() images, where values, assuming you have a training dataset of boxes.size() images, where
the ith image contains objects boxes[i] you want to detect and the the ith image contains objects boxes[i] you want to detect.
objects are clearly visible when scale so that they are target_size - The most important thing this function does is decide what detector
pixels in area. windows should be used. This is done by finding a set of detector
- In particular, this function will automatically set the detector width windows that are sized such that:
and height based on the average box size in boxes and the requested - When slid over an image pyramid, each box in boxes will have an
target_size. intersection-over-union with one of the detector windows of at least
min_detector_window_overlap_iou. That is, we will make sure that
each box in boxes could potentially be detected by one of the
detector windows. This essentially comes down to picking detector
windows with aspect ratios similar to the aspect ratios in boxes.
- The longest edge of each detector window is target_size pixels in
length, unless the window's shortest side would be less than
min_target_size pixels in length. In this case the shortest side
will be set to min_target_size length, and the other side sized to
preserve the aspect ratio of the window.
This means that target_size and min_target_size control the size of the
detector windows, while the aspect ratios of the detector windows are
automatically determined by the contents of boxes. It should also be
emphasized that the detector isn't going to be able to detect objects
smaller than any of the detector windows. So consider that when setting
these sizes.
- This function will also set the overlaps_nms tester to the most - This function will also set the overlaps_nms tester to the most
restrictive tester that doesn't reject anything in boxes. restrictive tester that doesn't reject anything in boxes.
!*/ !*/
......
...@@ -131,13 +131,18 @@ int main(int argc, char** argv) try ...@@ -131,13 +131,18 @@ int main(int argc, char** argv) try
// pick a good sliding window width and height. It will also automatically set the // pick a good sliding window width and height. It will also automatically set the
// non-max-suppression parameters to something reasonable. For further details see the // non-max-suppression parameters to something reasonable. For further details see the
// mmod_options documentation. // mmod_options documentation.
mmod_options options(face_boxes_train, 40*40); mmod_options options(face_boxes_train, 40,40);
cout << "detection window width,height: " << options.detector_width << "," << options.detector_height << endl; // The detector will automatically decide to use multiple sliding windows if needed.
// For the face data, however, only one is needed.
cout << "num detector windows: "<< options.detector_windows.size() << endl;
for (auto& w : options.detector_windows)
cout << "detector window width by height: " << w.width << " x " << w.height << endl;
cout << "overlap NMS IOU thresh: " << options.overlaps_nms.get_iou_thresh() << endl; cout << "overlap NMS IOU thresh: " << options.overlaps_nms.get_iou_thresh() << endl;
cout << "overlap NMS percent covered thresh: " << options.overlaps_nms.get_percent_covered_thresh() << endl; cout << "overlap NMS percent covered thresh: " << options.overlaps_nms.get_percent_covered_thresh() << endl;
// Now we are ready to create our network and trainer. // Now we are ready to create our network and trainer.
net_type net(options); net_type net(options);
net.subnet().layer_details().set_num_filters(options.detector_windows.size());
dnn_trainer<net_type> trainer(net); dnn_trainer<net_type> trainer(net);
trainer.set_learning_rate(0.1); trainer.set_learning_rate(0.1);
trainer.be_verbose(); trainer.be_verbose();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment