Commit 1928723a authored by Davis King

Moved label_to_ignore into loss_multiclass_log_per_pixel_ and also cleaned up a few minor things.
parent 4bc6c1e5
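
For callers, the move just changes how the constant is spelled: it is now a static member of the loss class rather than a free constant in the dlib namespace, which is exactly the substitution made in the test hunk at the bottom of this diff. A minimal sketch of the new spelling (the label_image variable and its size are illustrative, not part of the commit):

    #include <cstdint>
    #include <dlib/dnn.h>

    int main()
    {
        // Illustrative 4x4 per-pixel training label for one image (not from the commit).
        dlib::matrix<uint16_t> label_image = dlib::zeros_matrix<uint16_t>(4, 4);

        // After this commit the ignore value is a member of the loss class;
        // before it was the namespace-scope constant dlib::label_to_ignore.
        label_image(1, 2) = dlib::loss_multiclass_log_per_pixel_::label_to_ignore;
        return 0;
    }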
@@ -1531,15 +1531,16 @@ namespace dlib
// ----------------------------------------------------------------------------------------
// In semantic segmentation, if you don't know the ground-truth of some pixel,
// set the label of that pixel to this value. When you do so, the pixel will be
// ignored when computing gradients.
static const uint16_t label_to_ignore = std::numeric_limits<uint16_t>::max();
class loss_multiclass_log_per_pixel_
{
public:
// In semantic segmentation, if you don't know the ground-truth of some pixel,
// set the label of that pixel to this value. When you do so, the pixel will be
// ignored when computing gradients.
static const uint16_t label_to_ignore = std::numeric_limits<uint16_t>::max();
// In semantic segmentation, 65535 classes ought to be enough for anybody.
typedef matrix<uint16_t> training_label_type;
typedef matrix<uint16_t> output_label_type;
@@ -1565,12 +1566,15 @@ namespace dlib
const float* const out_data = output_tensor.host();
// The index of the largest output for each element is the label.
const auto find_label = [&](long sample, long r, long c) {
const auto find_label = [&](long sample, long r, long c)
{
uint16_t label = 0;
float max_value = out_data[tensor_index(output_tensor, sample, r, c, 0)];
for (long k = 1; k < output_tensor.k(); ++k) {
for (long k = 1; k < output_tensor.k(); ++k)
{
const float value = out_data[tensor_index(output_tensor, sample, r, c, k)];
if (value > max_value) {
if (value > max_value)
{
label = static_cast<uint16_t>(k);
max_value = value;
}
@@ -1578,10 +1582,13 @@ namespace dlib
return label;
};
for (long i = 0; i < output_tensor.num_samples(); ++i, ++iter) {
for (long i = 0; i < output_tensor.num_samples(); ++i, ++iter)
{
iter->set_size(output_tensor.nr(), output_tensor.nc());
for (long r = 0; r < output_tensor.nr(); ++r) {
for (long c = 0; c < output_tensor.nc(); ++c) {
for (long r = 0; r < output_tensor.nr(); ++r)
{
for (long c = 0; c < output_tensor.nc(); ++c)
{
// The index of the largest output for this element is the label.
iter->operator()(r, c) = find_label(i, r, c);
}
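The reformatted find_label lambda above is just a per-pixel argmax over the k output channels. Below is a standalone sketch of the same idea written against a plain channel-major float buffer instead of dlib's internal tensor_index helper; the buffer layout and the argmax_labels name are assumptions made for illustration, not code from this commit.

    #include <cstdint>
    #include <dlib/matrix.h>

    // Standalone per-pixel argmax, assuming a channel-major buffer for one sample:
    // out[k*rows*cols + r*cols + c] holds the network output for class k at pixel (r,c).
    dlib::matrix<uint16_t> argmax_labels(const float* out, long k, long rows, long cols)
    {
        dlib::matrix<uint16_t> labels(rows, cols);
        for (long r = 0; r < rows; ++r)
        {
            for (long c = 0; c < cols; ++c)
            {
                uint16_t best = 0;
                float best_value = out[0*rows*cols + r*cols + c];
                for (long kk = 1; kk < k; ++kk)
                {
                    const float value = out[kk*rows*cols + r*cols + c];
                    if (value > best_value)
                    {
                        best = static_cast<uint16_t>(kk);
                        best_value = value;
                    }
                }
                // The channel with the largest output is the predicted label.
                labels(r, c) = best;
            }
        }
        return labels;
    }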
@@ -808,12 +808,17 @@ namespace dlib
EXAMPLE_LOSS_LAYER_. In particular, it implements the multiclass logistic
regression loss (e.g. negative log-likelihood loss), which is appropriate
for multiclass classification problems. It is basically just like
loss_multiclass_log_ except that it lets you define matrix output instead
of scalar. It should be useful, for example, in semantic segmentation where
we want to classify each pixel of an image.
loss_multiclass_log_ except that it lets you define matrix outputs instead
of scalar outputs. It should be useful, for example, in semantic
segmentation where we want to classify each pixel of an image.
!*/
public:
// In semantic segmentation, if you don't know the ground-truth of some pixel,
// set the label of that pixel to this value. When you do so, the pixel will be
// ignored when computing gradients.
static const uint16_t label_to_ignore = std::numeric_limits<uint16_t>::max();
// In semantic segmentation, 65535 classes ought to be enough for anybody.
typedef matrix<uint16_t> training_label_type;
typedef matrix<uint16_t> output_label_type;
@@ -850,7 +855,7 @@ namespace dlib
except it has the additional calling requirements that:
- sub.get_output().num_samples() == input_tensor.num_samples()
- sub.sample_expansion_factor() == 1
- all values pointed to by truth are < sub.get_output().k() (or std::numeric_limits<uint16_t>::max() to ignore)
- all values pointed to by truth are < sub.get_output().k() or are equal to label_to_ignore.
!*/
};
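The abstract requires every truth value to be either a real class index below sub.get_output().k() or exactly label_to_ignore. A hedged sketch of preparing such a label image from an annotation map that marks unknown pixels with -1 (the raw_annotation input and its -1 convention are assumptions for illustration, not part of dlib):

    #include <cstdint>
    #include <dlib/dnn.h>

    // Convert an annotation that marks unknown pixels with -1 into a training label for
    // loss_multiclass_log_per_pixel_: known pixels keep their class index (which must be
    // < the net's output k()), unknown pixels become label_to_ignore.
    dlib::matrix<uint16_t> to_training_label(const dlib::matrix<int>& raw_annotation)
    {
        dlib::matrix<uint16_t> label(raw_annotation.nr(), raw_annotation.nc());
        for (long r = 0; r < raw_annotation.nr(); ++r)
        {
            for (long c = 0; c < raw_annotation.nc(); ++c)
            {
                if (raw_annotation(r, c) < 0)
                    label(r, c) = dlib::loss_multiclass_log_per_pixel_::label_to_ignore;
                else
                    label(r, c) = static_cast<uint16_t>(raw_annotation(r, c));
            }
        }
        return label;
    }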
@@ -2243,7 +2243,7 @@ namespace
DLIB_TEST(truth < num_classes);
++truth_histogram[truth];
if (ignore(generator)) {
ytmp(jj, kk) = label_to_ignore;
ytmp(jj, kk) = loss_multiclass_log_per_pixel_::label_to_ignore;
}
else if (noise_occurrence(generator)) {
ytmp(jj, kk) = noisy_label(generator);
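The test now reaches the constant through its new home in the loss class. The ignore, noise_occurrence, and noisy_label objects it draws from are defined earlier in the test and are not shown in this hunk; a self-contained sketch of the same kind of setup, with made-up sizes and probabilities, might look like this (illustrative only, not the actual test code):

    #include <cstdint>
    #include <random>
    #include <dlib/dnn.h>

    int main()
    {
        const long num_classes = 5;                        // made-up value for illustration
        std::mt19937 generator(0);
        std::bernoulli_distribution ignore(0.1);           // ~10% of pixels have no ground truth
        std::uniform_int_distribution<int> random_label(0, num_classes - 1);

        dlib::matrix<uint16_t> ytmp(8, 8);
        for (long r = 0; r < ytmp.nr(); ++r)
        {
            for (long c = 0; c < ytmp.nc(); ++c)
            {
                if (ignore(generator))
                    ytmp(r, c) = dlib::loss_multiclass_log_per_pixel_::label_to_ignore;
                else
                    ytmp(r, c) = static_cast<uint16_t>(random_label(generator));
            }
        }
        return 0;
    }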