Commit 29db3ee5 authored by Davis King's avatar Davis King

Added missing input validation to loss_mmod_. Specifically, the loss layer now

checks if the user is giving truth boxes that can't be detected because the
non-max suppression settings would prevent them from being output at the same
time.  If this happens then we print a warning message and set one of the
offending boxes to "ignore".
parent bf55c4e8
......@@ -697,12 +697,12 @@ namespace dlib
double loss = 0;
float* g = grad.host_write_only();
// zero initialize grad.
for (auto&& x : grad)
x = 0;
for (size_t i = 0; i < grad.size(); ++i)
g[i] = 0;
const float* out_data = output_tensor.host();
std::vector<size_t> truth_idxs; truth_idxs.reserve(truth->size());
std::vector<intermediate_detection> dets;
for (long i = 0; i < output_tensor.num_samples(); ++i)
{
......@@ -726,14 +726,17 @@ namespace dlib
loss -= 1;
continue;
}
loss -= out_data[(k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x()];
const size_t idx = (k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x();
loss -= out_data[idx];
// compute gradient
g[(k*output_tensor.nr() + p.y())*output_tensor.nc() + p.x()] = -scale;
g[idx] = -scale;
truth_idxs.push_back(idx);
}
else
{
// This box was ignored so shouldn't have been counted in the loss.
loss -= 1;
truth_idxs.push_back(0);
}
}
......@@ -772,6 +775,33 @@ namespace dlib
}
}
// Check if any of the truth boxes are unobtainable because the NMS is
// killing them. If so, automatically set those unobtainable boxes to
// ignore and print a warning message to the user.
for (size_t i = 0; i < hit_truth_table.size(); ++i)
{
if (!hit_truth_table[i] && !(*truth)[i].ignore)
{
// So we didn't hit this truth box. Is that because there is
// another, different truth box, that overlaps it according to NMS?
const std::pair<double,unsigned int> hittruth = find_best_match(*truth, (*truth)[i], i);
if (hittruth.second == i)
continue;
rectangle best_matching_truth_box = (*truth)[hittruth.second];
if (options.overlaps_nms(best_matching_truth_box, (*truth)[i]))
{
const size_t idx = truth_idxs[i];
// We are ignoring this box so we shouldn't have counted it in the
// loss in the first place. So we subtract out the loss values we
// added for it in the code above.
loss -= 1-out_data[idx];
g[idx] = 0;
std::cout << "Warning, ignoring object. We encountered a truth rectangle located at " << (*truth)[i].rect;
std::cout << " that is suppressed by non-max-suppression ";
std::cout << "because it is overlapped by another truth rectangle located at " << best_matching_truth_box << "." << std::endl;
}
}
}
hit_truth_table.assign(hit_truth_table.size(), false);
final_dets.clear();
......@@ -1012,12 +1042,21 @@ namespace dlib
const std::vector<mmod_rect>& boxes,
const rectangle& rect
) const
{
return find_best_match(boxes, rect, boxes.size());
}
std::pair<double,unsigned int> find_best_match(
const std::vector<mmod_rect>& boxes,
const rectangle& rect,
const size_t excluded_idx
) const
{
double match = 0;
unsigned int best_idx = 0;
for (unsigned long i = 0; i < boxes.size(); ++i)
{
if (boxes[i].ignore)
if (boxes[i].ignore || excluded_idx == i)
continue;
const double new_match = box_intersection_over_union(rect, boxes[i]);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment