Commit f5574434 authored by Davis King's avatar Davis King

Upgraded loss_mmod_ to support objects of varying aspect ratio. This changes

the API for the mmod_options struct slightly.
parent dcfff1c4
...@@ -364,10 +364,37 @@ namespace dlib ...@@ -364,10 +364,37 @@ namespace dlib
{ {
public: public:
struct detector_window_size
{
detector_window_size() = default;
detector_window_size(unsigned long w, unsigned long h) : width(w), height(h) {}
unsigned long width = 0;
unsigned long height = 0;
friend inline void serialize(const detector_window_size& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.width, out);
serialize(item.height, out);
}
friend inline void deserialize(detector_window_size& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::mmod_options::detector_window_size");
deserialize(item.width, in);
deserialize(item.height, in);
}
};
mmod_options() = default; mmod_options() = default;
unsigned long detector_width = 80; std::vector<detector_window_size> detector_windows;
unsigned long detector_height = 80;
double loss_per_false_alarm = 1; double loss_per_false_alarm = 1;
double loss_per_missed_target = 1; double loss_per_missed_target = 1;
double truth_match_iou_threshold = 0.5; double truth_match_iou_threshold = 0.5;
...@@ -376,14 +403,51 @@ namespace dlib ...@@ -376,14 +403,51 @@ namespace dlib
mmod_options ( mmod_options (
const std::vector<std::vector<mmod_rect>>& boxes, const std::vector<std::vector<mmod_rect>>& boxes,
const unsigned long target_size = 6400 const unsigned long target_size, // We want the length of the longest dimension of the detector window to be this.
const unsigned long min_target_size, // But we require that the smallest dimension of the detector window be at least this big.
const double min_detector_window_overlap_iou = 0.75
) )
{ {
std::vector<std::vector<rectangle>> temp; DLIB_CASSERT(0 < min_target_size && min_target_size <= target_size);
DLIB_CASSERT(0.5 < min_detector_window_overlap_iou && min_detector_window_overlap_iou < 1);
// Figure out what detector windows we will need.
for (auto ratio : find_covering_aspect_ratios(boxes, test_box_overlap(min_detector_window_overlap_iou)))
{
double detector_width;
double detector_height;
if (ratio < 1)
{
detector_height = target_size;
detector_width = ratio*target_size;
if (detector_width < min_target_size)
{
detector_height = min_target_size;
detector_width = min_target_size/ratio;
}
}
else
{
detector_width = target_size;
detector_height = target_size/ratio;
if (detector_height < min_target_size)
{
detector_width = min_target_size;
detector_height = min_target_size*ratio;
}
}
// find the average width and height. Then we will set the detector width and detector_window_size p((unsigned long)std::round(detector_width), (unsigned long)std::round(detector_height));
// height to match the average aspect ratio of the boxes given the target_size. detector_windows.push_back(p);
running_stats<double> avg_width, avg_height; }
DLIB_CASSERT(detector_windows.size() != 0, "You can't call mmod_options's constructor with a set of boxes that is empty (or only contains ignored boxes).");
// Convert from mmod_rect to rectangle so we can call
// find_tight_overlap_tester().
std::vector<std::vector<rectangle>> temp;
for (auto&& bi : boxes) for (auto&& bi : boxes)
{ {
std::vector<rectangle> rtemp; std::vector<rectangle> rtemp;
...@@ -391,26 +455,10 @@ namespace dlib ...@@ -391,26 +455,10 @@ namespace dlib
{ {
if (b.ignore) if (b.ignore)
continue; continue;
avg_width.add(b.rect.width());
avg_height.add(b.rect.height());
rtemp.push_back(b.rect); rtemp.push_back(b.rect);
} }
temp.push_back(std::move(rtemp)); temp.push_back(std::move(rtemp));
} }
// now adjust the box size so that it is about target_pixels pixels in size
double size = avg_width.mean()*avg_height.mean();
double scale = std::sqrt(target_size/size);
detector_width = (unsigned long)(avg_width.mean()*scale+0.5);
detector_height = (unsigned long)(avg_height.mean()*scale+0.5);
// make sure the width and height never round to zero.
if (detector_width == 0)
detector_width = 1;
if (detector_height == 0)
detector_height = 1;
overlaps_nms = find_tight_overlap_tester(temp); overlaps_nms = find_tight_overlap_tester(temp);
// Relax the non-max-suppression a little so that it doesn't accidentally make // Relax the non-max-suppression a little so that it doesn't accidentally make
// it impossible for the detector to output boxes matching the training data. // it impossible for the detector to output boxes matching the training data.
...@@ -418,20 +466,106 @@ namespace dlib ...@@ -418,20 +466,106 @@ namespace dlib
// some small variability in how boxes get positioned between the training data // some small variability in how boxes get positioned between the training data
// and the coordinate system used by the detector when it runs. So relaxing it // and the coordinate system used by the detector when it runs. So relaxing it
// here takes care of that. // here takes care of that.
double relax_amount = 0.10; const double relax_amount = 0.05;
auto iou_thresh = std::min(1.0, overlaps_nms.get_iou_thresh()+relax_amount); auto iou_thresh = std::min(1.0, overlaps_nms.get_iou_thresh()+relax_amount);
auto percent_covered_thresh = std::min(1.0, overlaps_nms.get_percent_covered_thresh()+relax_amount); auto percent_covered_thresh = std::min(1.0, overlaps_nms.get_percent_covered_thresh()+relax_amount);
overlaps_nms = test_box_overlap(iou_thresh, percent_covered_thresh); overlaps_nms = test_box_overlap(iou_thresh, percent_covered_thresh);
} }
private:
static size_t count_overlaps (
const std::vector<rectangle>& rects,
const test_box_overlap& overlaps,
const rectangle& ref_box
)
{
size_t cnt = 0;
for (auto& b : rects)
{
if (overlaps(b, ref_box))
++cnt;
}
return cnt;
}
static std::vector<rectangle> find_rectangles_overlapping_all_others (
std::vector<rectangle> rects,
const test_box_overlap& overlaps
)
{
std::vector<rectangle> exemplars;
dlib::rand rnd;
while(rects.size() > 0)
{
// Pick boxes at random and see if they overlap a lot of other boxes. We will try
// 500 different boxes each iteration and select whichever hits the most others to
// add to our exemplar set.
rectangle best_ref_box;
size_t best_cnt = 0;
for (int iter = 0; iter < 500; ++iter)
{
rectangle ref_box = rects[rnd.get_random_64bit_number()%rects.size()];
size_t cnt = count_overlaps(rects, overlaps, ref_box);
if (cnt >= best_cnt)
{
best_cnt = cnt;
best_ref_box = ref_box;
}
}
// Now mark all the boxes the new ref box hit as hit.
for (size_t i = 0; i < rects.size(); ++i)
{
if (overlaps(rects[i], best_ref_box))
{
// remove box from rects so we don't hit it again later
swap(rects[i], rects.back());
rects.pop_back();
--i;
}
}
exemplars.push_back(best_ref_box);
}
return exemplars;
}
static std::vector<double> find_covering_aspect_ratios (
const std::vector<std::vector<mmod_rect>>& rects,
const test_box_overlap& overlaps
)
{
std::vector<rectangle> boxes;
// Make sure all the boxes have the same size and position, so that the only thing our
// checks for overlap will care about is aspect ratio (i.e. scale and x,y position are
// ignored).
for (auto& bb : rects)
{
for (auto&& b : bb)
{
if (!b.ignore)
boxes.push_back(move_rect(set_rect_area(b.rect,400*400), point(0,0)));
}
}
std::vector<double> ratios;
for (auto r : find_rectangles_overlapping_all_others(boxes, overlaps))
ratios.push_back(r.width()/(double)r.height());
return ratios;
}
}; };
inline void serialize(const mmod_options& item, std::ostream& out) inline void serialize(const mmod_options& item, std::ostream& out)
{ {
int version = 1; int version = 2;
serialize(version, out); serialize(version, out);
serialize(item.detector_width, out); serialize(item.detector_windows, out);
serialize(item.detector_height, out);
serialize(item.loss_per_false_alarm, out); serialize(item.loss_per_false_alarm, out);
serialize(item.loss_per_missed_target, out); serialize(item.loss_per_missed_target, out);
serialize(item.truth_match_iou_threshold, out); serialize(item.truth_match_iou_threshold, out);
...@@ -443,10 +577,20 @@ namespace dlib ...@@ -443,10 +577,20 @@ namespace dlib
{ {
int version = 0; int version = 0;
deserialize(version, in); deserialize(version, in);
if (version != 1) if (version != 2 && version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::mmod_options"); throw serialization_error("Unexpected version found while deserializing dlib::mmod_options");
deserialize(item.detector_width, in); if (version == 1)
deserialize(item.detector_height, in); {
unsigned long width;
unsigned long height;
deserialize(width, in);
deserialize(height, in);
item.detector_windows = {mmod_options::detector_window_size(width, height)};
}
else
{
deserialize(item.detector_windows, in);
}
deserialize(item.loss_per_false_alarm, in); deserialize(item.loss_per_false_alarm, in);
deserialize(item.loss_per_missed_target, in); deserialize(item.loss_per_missed_target, in);
deserialize(item.truth_match_iou_threshold, in); deserialize(item.truth_match_iou_threshold, in);
...@@ -503,7 +647,7 @@ namespace dlib ...@@ -503,7 +647,7 @@ namespace dlib
) const ) const
{ {
const tensor& output_tensor = sub.get_output(); const tensor& output_tensor = sub.get_output();
DLIB_CASSERT(output_tensor.k() == 1); DLIB_CASSERT(output_tensor.k() == options.detector_windows.size());
DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
DLIB_CASSERT(sub.sample_expansion_factor() == 1, sub.sample_expansion_factor()); DLIB_CASSERT(sub.sample_expansion_factor() == 1, sub.sample_expansion_factor());
...@@ -544,7 +688,7 @@ namespace dlib ...@@ -544,7 +688,7 @@ namespace dlib
DLIB_CASSERT(sub.sample_expansion_factor() == 1); DLIB_CASSERT(sub.sample_expansion_factor() == 1);
DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples()); DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
DLIB_CASSERT(output_tensor.k() == 1); DLIB_CASSERT(output_tensor.k() == options.detector_windows.size());
...@@ -574,10 +718,17 @@ namespace dlib ...@@ -574,10 +718,17 @@ namespace dlib
{ {
if (!x.ignore) if (!x.ignore)
{ {
point p = image_rect_to_feat_coord(input_tensor, x, sub); size_t k;
loss -= out_data[p.y()*output_tensor.nc() + p.x()]; point p;
if(image_rect_to_feat_coord(p, input_tensor, x, sub, k))
{
// Ignore boxes that can't be detected by the CNN.
loss -= 1;
continue;
}
loss -= out_data[(k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x()];
// compute gradient // compute gradient
g[p.y()*output_tensor.nc() + p.x()] = -scale; g[(k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x()] = -scale;
} }
else else
{ {
...@@ -670,8 +821,8 @@ namespace dlib ...@@ -670,8 +821,8 @@ namespace dlib
} }
++truth; ++truth;
g += output_tensor.nr()*output_tensor.nc(); g += output_tensor.k()*output_tensor.nr()*output_tensor.nc();
out_data += output_tensor.nr()*output_tensor.nc(); out_data += output_tensor.k()*output_tensor.nr()*output_tensor.nc();
} // END for (long i = 0; i < output_tensor.num_samples(); ++i) } // END for (long i = 0; i < output_tensor.num_samples(); ++i)
...@@ -726,34 +877,64 @@ namespace dlib ...@@ -726,34 +877,64 @@ namespace dlib
) const ) const
{ {
DLIB_CASSERT(net.sample_expansion_factor() == 1,net.sample_expansion_factor()); DLIB_CASSERT(net.sample_expansion_factor() == 1,net.sample_expansion_factor());
DLIB_CASSERT(output_tensor.k() == 1); DLIB_CASSERT(output_tensor.k() == options.detector_windows.size());
const float* out_data = output_tensor.host() + output_tensor.nr()*output_tensor.nc()*i; const float* out_data = output_tensor.host() + output_tensor.k()*output_tensor.nr()*output_tensor.nc()*i;
// scan the final layer and output the positive scoring locations // scan the final layer and output the positive scoring locations
dets_accum.clear(); dets_accum.clear();
for (long r = 0; r < output_tensor.nr(); ++r) for (long k = 0; k < output_tensor.k(); ++k)
{ {
for (long c = 0; c < output_tensor.nc(); ++c) for (long r = 0; r < output_tensor.nr(); ++r)
{ {
double score = out_data[r*output_tensor.nc() + c]; for (long c = 0; c < output_tensor.nc(); ++c)
if (score > adjust_threshold)
{ {
dpoint p = output_tensor_to_input_tensor(net, point(c,r)); double score = out_data[(k*output_tensor.nr() + r)*output_tensor.nc() + c];
drectangle rect = centered_drect(p, options.detector_width, options.detector_height); if (score > adjust_threshold)
rect = input_layer(net).tensor_space_to_image_space(input_tensor,rect); {
dpoint p = output_tensor_to_input_tensor(net, point(c,r));
drectangle rect = centered_drect(p, options.detector_windows[k].width, options.detector_windows[k].height);
rect = input_layer(net).tensor_space_to_image_space(input_tensor,rect);
dets_accum.push_back(intermediate_detection(rect, score, r*output_tensor.nc() + c)); dets_accum.push_back(intermediate_detection(rect, score, (k*output_tensor.nr() + r)*output_tensor.nc() + c));
}
} }
} }
} }
std::sort(dets_accum.rbegin(), dets_accum.rend()); std::sort(dets_accum.rbegin(), dets_accum.rend());
} }
size_t find_best_detection_window (
rectangle rect
) const
{
rect = move_rect(set_rect_area(rect, 400*400), point(0,0));
// Figure out which detection window in options.detector_windows has the most
// similar aspect ratio to rect.
const double aspect_ratio = rect.width()/(double)rect.height();
size_t best_i = 0;
double best_ratio_diff = -std::numeric_limits<double>::infinity();
for (size_t i = 0; i < options.detector_windows.size(); ++i)
{
rectangle det_window = centered_rect(point(0,0), options.detector_windows[i].width, options.detector_windows[i].height);
det_window = move_rect(set_rect_area(det_window, 400*400), point(0,0));
double iou = box_intersection_over_union(rect, det_window);
if (iou > best_ratio_diff)
{
best_ratio_diff = iou;
best_i = i;
}
}
return best_i;
}
template <typename net_type> template <typename net_type>
point image_rect_to_feat_coord ( bool image_rect_to_feat_coord (
point& tensor_p,
const tensor& input_tensor, const tensor& input_tensor,
const rectangle& rect, const rectangle& rect,
const net_type& net const net_type& net,
size_t& det_idx
) const ) const
{ {
using namespace std; using namespace std;
...@@ -765,38 +946,39 @@ namespace dlib ...@@ -765,38 +946,39 @@ namespace dlib
throw impossible_labeling_error(sout.str()); throw impossible_labeling_error(sout.str());
} }
det_idx = find_best_detection_window(rect);
// Compute the scale we need to be at to get from rect to our detection window. // Compute the scale we need to be at to get from rect to our detection window.
// Note that we compute the scale as the max of two numbers. It doesn't // Note that we compute the scale as the max of two numbers. It doesn't
// actually matter which one we pick, because if they are very different then // actually matter which one we pick, because if they are very different then
// it means the box can't be matched by the sliding window. But picking the // it means the box can't be matched by the sliding window. But picking the
// max causes the right error message to be selected in the logic below. // max causes the right error message to be selected in the logic below.
const double scale = std::max(options.detector_width/(double)rect.width(), options.detector_height/(double)rect.height()); const double scale = std::max(options.detector_windows[det_idx].width/(double)rect.width(), options.detector_windows[det_idx].height/(double)rect.height());
const rectangle mapped_rect = input_layer(net).image_space_to_tensor_space(input_tensor, std::min(1.0,scale), rect); const rectangle mapped_rect = input_layer(net).image_space_to_tensor_space(input_tensor, std::min(1.0,scale), rect);
// compute the detection window that we would use at this position. // compute the detection window that we would use at this position.
point tensor_p = center(mapped_rect); tensor_p = center(mapped_rect);
rectangle det_window = centered_rect(tensor_p, options.detector_width,options.detector_height); rectangle det_window = centered_rect(tensor_p, options.detector_windows[det_idx].width,options.detector_windows[det_idx].height);
det_window = input_layer(net).tensor_space_to_image_space(input_tensor, det_window); det_window = input_layer(net).tensor_space_to_image_space(input_tensor, det_window);
// make sure the rect can actually be represented by the image pyramid we are // make sure the rect can actually be represented by the image pyramid we are
// using. // using.
if (box_intersection_over_union(rect, det_window) <= options.truth_match_iou_threshold) if (box_intersection_over_union(rect, det_window) <= options.truth_match_iou_threshold)
{ {
std::ostringstream sout; std::cout << "Warning, ignoring object. We encountered a truth rectangle with a width and height of " << rect.width() << " and " << rect.height() << ". ";
sout << "Encountered a truth rectangle with a width and height of " << rect.width() << " and " << rect.height() << "." << endl; std::cout << "The image pyramid and sliding windows can't output a rectangle of this shape. ";
sout << "The image pyramid and sliding window can't output a rectangle of this shape. " << endl; const double detector_area = options.detector_windows[det_idx].width*options.detector_windows[det_idx].height;
const double detector_area = options.detector_width*options.detector_height;
if (mapped_rect.area()/detector_area <= options.truth_match_iou_threshold) if (mapped_rect.area()/detector_area <= options.truth_match_iou_threshold)
{ {
sout << "This is because the rectangle is smaller than the detection window which has a width" << endl; std::cout << "This is because the rectangle is smaller than the best matching detection window, which has a width ";
sout << "and height of " << options.detector_width << " and " << options.detector_height << "." << endl; std::cout << "and height of " << options.detector_windows[det_idx].width << " and " << options.detector_windows[det_idx].height << "." << std::endl;
} }
else else
{ {
sout << "This is because the rectangle's aspect ratio is too different from the detection window," << endl; std::cout << "This is because the rectangle's aspect ratio is too different from the best matching detection window, ";
sout << "which has a width and height of " << options.detector_width << " and " << options.detector_height << "." << endl; std::cout << "which has a width and height of " << options.detector_windows[det_idx].width << " and " << options.detector_windows[det_idx].height << "." << std::endl;
} }
throw impossible_labeling_error(sout.str()); return true;
} }
// now map through the CNN to the output layer. // now map through the CNN to the output layer.
...@@ -805,13 +987,12 @@ namespace dlib ...@@ -805,13 +987,12 @@ namespace dlib
const tensor& output_tensor = net.get_output(); const tensor& output_tensor = net.get_output();
if (!get_rect(output_tensor).contains(tensor_p)) if (!get_rect(output_tensor).contains(tensor_p))
{ {
std::ostringstream sout; std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << rect << " that is too close to the edge ";
sout << "Encountered a truth rectangle located at " << rect << " that is too close to the edge" << endl; std::cout << "of the image to be captured by the CNN features." << std::endl;
sout << "of the image to be captured by the CNN features." << endl; return true;
throw impossible_labeling_error(sout.str());
} }
return tensor_p; return false;
} }
......
...@@ -369,14 +369,25 @@ namespace dlib ...@@ -369,14 +369,25 @@ namespace dlib
public: public:
struct detector_window_size
{
detector_window_size() = default;
detector_window_size(unsigned long w, unsigned long h) : width(w), height(h) {}
unsigned long width = 0;
unsigned long height = 0;
friend inline void serialize(const detector_window_size& item, std::ostream& out);
friend inline void deserialize(detector_window_size& item, std::istream& in);
};
mmod_options() = default; mmod_options() = default;
// This kind of object detector is a sliding window detector. These two parameters // This kind of object detector is a sliding window detector. The detector_windows
// determine the size of the sliding window. Since you will usually use the MMOD // field determines how many sliding windows we will use and what the shape of each
// loss with an image pyramid the detector size determines the size of the smallest // window is. Since you will usually use the MMOD loss with an image pyramid, the
// object you can detect. // detector sizes also determine the size of the smallest object you can detect.
unsigned long detector_width = 80; std::vector<detector_window_size> detector_windows;
unsigned long detector_height = 80;
// These parameters control how we penalize different kinds of mistakes. See // These parameters control how we penalize different kinds of mistakes. See
// Max-Margin Object Detection by Davis E. King (http://arxiv.org/abs/1502.00046) // Max-Margin Object Detection by Davis E. King (http://arxiv.org/abs/1502.00046)
...@@ -402,18 +413,38 @@ namespace dlib ...@@ -402,18 +413,38 @@ namespace dlib
mmod_options ( mmod_options (
const std::vector<std::vector<mmod_rect>>& boxes, const std::vector<std::vector<mmod_rect>>& boxes,
const unsigned long target_size = 6400 const unsigned long target_size,
const unsigned long min_target_size,
const double min_detector_window_overlap_iou = 0.75
); );
/*! /*!
requires
- 0 < min_target_size <= target_size
- 0.5 < min_detector_window_overlap_iou < 1
ensures ensures
- This function tries to automatically set the MMOD options so reasonable - This function tries to automatically set the MMOD options to reasonable
values assuming you have a training dataset of boxes.size() images, where values, assuming you have a training dataset of boxes.size() images, where
the ith image contains objects boxes[i] you want to detect and the the ith image contains objects boxes[i] you want to detect.
objects are clearly visible when scale so that they are target_size - The most important thing this function does is decide what detector
pixels in area. windows should be used. This is done by finding a set of detector
- In particular, this function will automatically set the detector width windows that are sized such that:
and height based on the average box size in boxes and the requested - When slid over an image pyramid, each box in boxes will have an
target_size. intersection-over-union with one of the detector windows of at least
min_detector_window_overlap_iou. That is, we will make sure that
each box in boxes could potentially be detected by one of the
detector windows. This essentially comes down to picking detector
windows with aspect ratios similar to the aspect ratios in boxes.
- The longest edge of each detector window is target_size pixels in
length, unless the window's shortest side would be less than
min_target_size pixels in length. In this case the shortest side
will be set to min_target_size length, and the other side sized to
preserve the aspect ratio of the window.
This means that, target_size and min_target_size control the size of the
detector windows, while the aspect ratios of the detector windows are
automatically determined by the contents of boxes. It should also be
emphasized that the detector isn't going to be able to detect objects
smaller than any of the detector windows. So consider that when setting
these sizes.
- This function will also set the overlaps_nms tester to the most - This function will also set the overlaps_nms tester to the most
restrictive tester that doesn't reject anything in boxes. restrictive tester that doesn't reject anything in boxes.
!*/ !*/
......
...@@ -131,13 +131,18 @@ int main(int argc, char** argv) try ...@@ -131,13 +131,18 @@ int main(int argc, char** argv) try
// pick a good sliding window width and height. It will also automatically set the // pick a good sliding window width and height. It will also automatically set the
// non-max-suppression parameters to something reasonable. For further details see the // non-max-suppression parameters to something reasonable. For further details see the
// mmod_options documentation. // mmod_options documentation.
mmod_options options(face_boxes_train, 40*40); mmod_options options(face_boxes_train, 40,40);
cout << "detection window width,height: " << options.detector_width << "," << options.detector_height << endl; // The detector will automatically decide to use multiple sliding windows if needed.
// For the face data, only one is needed however.
cout << "num detector windows: "<< options.detector_windows.size() << endl;
for (auto& w : options.detector_windows)
cout << "detector window width by height: " << w.width << " x " << w.height << endl;
cout << "overlap NMS IOU thresh: " << options.overlaps_nms.get_iou_thresh() << endl; cout << "overlap NMS IOU thresh: " << options.overlaps_nms.get_iou_thresh() << endl;
cout << "overlap NMS percent covered thresh: " << options.overlaps_nms.get_percent_covered_thresh() << endl; cout << "overlap NMS percent covered thresh: " << options.overlaps_nms.get_percent_covered_thresh() << endl;
// Now we are ready to create our network and trainer. // Now we are ready to create our network and trainer.
net_type net(options); net_type net(options);
net.subnet().layer_details().set_num_filters(options.detector_windows.size());
dnn_trainer<net_type> trainer(net); dnn_trainer<net_type> trainer(net);
trainer.set_learning_rate(0.1); trainer.set_learning_rate(0.1);
trainer.be_verbose(); trainer.be_verbose();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment