Commit fc927746 authored by Juha Reunanen, committed by Davis E. King

Add per-pixel mean square loss (#690)

* Add per-pixel mean square loss

* Add documentation of loss_mean_squared_per_pixel_

* Add test case for per-pixel mean square loss: a simple autoencoder

* Review fix: reorder params of function tensor_index, so that the order corresponds to the convention used in the rest of the dlib code base (see the short index sketch after this list)

* Review fix: add breaks as intended, and change the rest of the test accordingly

* Again a case where the tests already work locally for me, but not on AppVeyor/Travis - this commit is a blindfolded attempt to fix the problem
(and it also fixes a compiler warning)
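
To illustrate the (sample, k, row, column) convention referred to in the tensor_index review fix above, here is a small standalone sketch; the helper name index_of and the example numbers are illustrative only and are not part of the commit:

#include <cstddef>

// Standalone sketch of dlib's row-major tensor layout: elements are ordered by
// (sample, k, row, column), which is why tensor_index now takes its arguments in
// that order. The formula mirrors the tensor_index helper shown in the diff below.
static std::size_t index_of(long num_k, long num_rows, long num_cols,
                            long sample, long k, long row, long column)
{
    return ((sample * num_k + k) * num_rows + row) * num_cols + column;
}

// Example: with k()=3, nr()=4, nc()=5, the element (sample=1, k=2, row=3, column=4)
// lives at ((1*3 + 2)*4 + 3)*5 + 4 = 119.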
parent daedd901
@@ -1569,10 +1569,10 @@ namespace dlib
            const auto find_label = [&](long sample, long r, long c)
            {
                uint16_t label = 0;
-                float max_value = out_data[tensor_index(output_tensor, sample, r, c, 0)];
+                float max_value = out_data[tensor_index(output_tensor, sample, 0, r, c)];
                for (long k = 1; k < output_tensor.k(); ++k)
                {
-                    const float value = out_data[tensor_index(output_tensor, sample, r, c, k)];
+                    const float value = out_data[tensor_index(output_tensor, sample, k, r, c)];
                    if (value > max_value)
                    {
                        label = static_cast<uint16_t>(k);
@@ -1647,7 +1647,7 @@ namespace dlib
                        "y: " << y << ", output_tensor.k(): " << output_tensor.k());
                    for (long k = 0; k < output_tensor.k(); ++k)
                    {
-                        const size_t idx = tensor_index(output_tensor, i, r, c, k);
+                        const size_t idx = tensor_index(output_tensor, i, k, r, c);
                        if (k == y)
                        {
                            loss += scale*-std::log(g[idx]);
@@ -1693,7 +1693,7 @@ namespace dlib
        }

    private:
-        static size_t tensor_index(const tensor& t, long sample, long row, long column, long k)
+        static size_t tensor_index(const tensor& t, long sample, long k, long row, long column)
        {
            // See: https://github.com/davisking/dlib/blob/4dfeb7e186dd1bf6ac91273509f687293bd4230a/dlib/dnn/tensor_abstract.h#L38
            return ((sample * t.k() + k) * t.nr() + row) * t.nc() + column;
@@ -1793,7 +1793,7 @@ namespace dlib
                        "y: " << y << ", output_tensor.k(): " << output_tensor.k());
                    for (long k = 0; k < output_tensor.k(); ++k)
                    {
-                        const size_t idx = tensor_index(output_tensor, i, r, c, k);
+                        const size_t idx = tensor_index(output_tensor, i, k, r, c);
                        if (k == y)
                        {
                            loss += weight*scale*-std::log(g[idx]);
@@ -1835,7 +1835,7 @@ namespace dlib
        }

    private:
-        static size_t tensor_index(const tensor& t, long sample, long row, long column, long k)
+        static size_t tensor_index(const tensor& t, long sample, long k, long row, long column)
        {
            // See: https://github.com/davisking/dlib/blob/4dfeb7e186dd1bf6ac91273509f687293bd4230a/dlib/dnn/tensor_abstract.h#L38
            return ((sample * t.k() + k) * t.nr() + row) * t.nc() + column;
@@ -1848,6 +1848,138 @@ namespace dlib
// ----------------------------------------------------------------------------------------
    class loss_mean_squared_per_pixel_
    {
    public:

        typedef matrix<float> training_label_type;
        typedef matrix<float> output_label_type;

        template <
            typename SUB_TYPE,
            typename label_iterator
            >
        void to_label (
            const tensor& input_tensor,
            const SUB_TYPE& sub,
            label_iterator iter
        ) const
        {
            DLIB_CASSERT(sub.sample_expansion_factor() == 1);

            const tensor& output_tensor = sub.get_output();

            DLIB_CASSERT(output_tensor.k() == 1, "output k = " << output_tensor.k());
            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());

            const float* out_data = output_tensor.host();

            for (long i = 0; i < output_tensor.num_samples(); ++i, ++iter)
            {
                iter->set_size(output_tensor.nr(), output_tensor.nc());
                for (long r = 0; r < output_tensor.nr(); ++r)
                {
                    for (long c = 0; c < output_tensor.nc(); ++c)
                    {
                        iter->operator()(r, c) = out_data[tensor_index(output_tensor, i, 0, r, c)];
                    }
                }
            }
        }

        template <
            typename const_label_iterator,
            typename SUBNET
            >
        double compute_loss_value_and_gradient (
            const tensor& input_tensor,
            const_label_iterator truth,
            SUBNET& sub
        ) const
        {
            const tensor& output_tensor = sub.get_output();
            tensor& grad = sub.get_gradient_input();

            DLIB_CASSERT(sub.sample_expansion_factor() == 1);
            DLIB_CASSERT(input_tensor.num_samples() != 0);
            DLIB_CASSERT(input_tensor.num_samples() % sub.sample_expansion_factor() == 0);
            DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
            DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
            DLIB_CASSERT(output_tensor.k() >= 1);
            DLIB_CASSERT(output_tensor.k() < std::numeric_limits<uint16_t>::max());
            DLIB_CASSERT(output_tensor.nr() == grad.nr() &&
                         output_tensor.nc() == grad.nc() &&
                         output_tensor.k() == grad.k());
            for (long idx = 0; idx < output_tensor.num_samples(); ++idx)
            {
                const_label_iterator truth_matrix_ptr = (truth + idx);
                DLIB_CASSERT(truth_matrix_ptr->nr() == output_tensor.nr() &&
                             truth_matrix_ptr->nc() == output_tensor.nc(),
                             "truth size = " << truth_matrix_ptr->nr() << " x " << truth_matrix_ptr->nc() << ", "
                             "output size = " << output_tensor.nr() << " x " << output_tensor.nc());
            }

            // The loss we output is the average loss over the mini-batch, and also over each element of the matrix output.
            const double scale = 1.0 / (output_tensor.num_samples() * output_tensor.nr() * output_tensor.nc());
            double loss = 0;
            float* const g = grad.host();
            const float* out_data = output_tensor.host();
            for (long i = 0; i < output_tensor.num_samples(); ++i, ++truth)
            {
                for (long r = 0; r < output_tensor.nr(); ++r)
                {
                    for (long c = 0; c < output_tensor.nc(); ++c)
                    {
                        const float y = truth->operator()(r, c);
                        const size_t idx = tensor_index(output_tensor, i, 0, r, c);
                        const float temp1 = y - out_data[idx];
                        const float temp2 = scale*temp1;
                        loss += 0.5*temp2*temp1;
                        g[idx] = -temp2;
                    }
                }
            }
            return loss;
        }

        friend void serialize(const loss_mean_squared_per_pixel_& , std::ostream& out)
        {
            serialize("loss_mean_squared_per_pixel_", out);
        }

        friend void deserialize(loss_mean_squared_per_pixel_& , std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "loss_mean_squared_per_pixel_")
                throw serialization_error("Unexpected version found while deserializing dlib::loss_mean_squared_per_pixel_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const loss_mean_squared_per_pixel_& )
        {
            out << "loss_mean_squared_per_pixel";
            return out;
        }

        friend void to_xml(const loss_mean_squared_per_pixel_& /*item*/, std::ostream& out)
        {
            out << "<loss_mean_squared_per_pixel/>";
        }

    private:
        static size_t tensor_index(const tensor& t, long sample, long k, long row, long column)
        {
            // See: https://github.com/davisking/dlib/blob/4dfeb7e186dd1bf6ac91273509f687293bd4230a/dlib/dnn/tensor_abstract.h#L38
            return ((sample * t.k() + k) * t.nr() + row) * t.nc() + column;
        }
    };

    template <typename SUBNET>
    using loss_mean_squared_per_pixel = add_loss_layer<loss_mean_squared_per_pixel_, SUBNET>;

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_DNn_LOSS_H_
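
As a reading aid for compute_loss_value_and_gradient above, the code's arithmetic corresponds to the following (with N = num_samples(), n_r = nr(), n_c = nc(), y the truth pixel, and \hat{y} the network output):

\ell = \frac{1}{2\,N\,n_r\,n_c} \sum_{i,r,c} \bigl(y_{i,r,c} - \hat{y}_{i,r,c}\bigr)^2,
\qquad
\frac{\partial \ell}{\partial \hat{y}_{i,r,c}} = -\frac{1}{N\,n_r\,n_c}\,\bigl(y_{i,r,c} - \hat{y}_{i,r,c}\bigr)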
...
@@ -951,6 +951,64 @@ namespace dlib
    template <typename SUBNET>
    using loss_multiclass_log_per_pixel_weighted = add_loss_layer<loss_multiclass_log_per_pixel_weighted_, SUBNET>;

// ----------------------------------------------------------------------------------------
    class loss_mean_squared_per_pixel_
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                This object implements the loss layer interface defined above by
                EXAMPLE_LOSS_LAYER_.  In particular, it implements the mean squared loss,
                which is appropriate for regression problems.  It is basically just like
                loss_mean_squared_multioutput_ except that it lets you define matrix or
                image outputs, instead of vector outputs.
        !*/
    public:

        typedef matrix<float> training_label_type;
        typedef matrix<float> output_label_type;

        template <
            typename SUB_TYPE,
            typename label_iterator
            >
        void to_label (
            const tensor& input_tensor,
            const SUB_TYPE& sub,
            label_iterator iter
        ) const;
        /*!
            This function has the same interface as EXAMPLE_LOSS_LAYER_::to_label() except
            it has the additional calling requirements that:
                - sub.get_output().num_samples() == input_tensor.num_samples()
                - sub.sample_expansion_factor() == 1
            and the output labels are the predicted continuous variables.
        !*/

        template <
            typename const_label_iterator,
            typename SUBNET
            >
        double compute_loss_value_and_gradient (
            const tensor& input_tensor,
            const_label_iterator truth,
            SUBNET& sub
        ) const;
        /*!
            This function has the same interface as EXAMPLE_LOSS_LAYER_::compute_loss_value_and_gradient()
            except it has the additional calling requirements that:
                - sub.get_output().k() == 1
                - sub.get_output().num_samples() == input_tensor.num_samples()
                - sub.sample_expansion_factor() == 1
                - for all idx such that 0 <= idx < sub.get_output().num_samples():
                    - sub.get_output().nr() == (*(truth + idx)).nr()
                    - sub.get_output().nc() == (*(truth + idx)).nc()
        !*/
    };

    template <typename SUBNET>
    using loss_mean_squared_per_pixel = add_loss_layer<loss_mean_squared_per_pixel_, SUBNET>;
// ----------------------------------------------------------------------------------------

}

...
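
For orientation, here is a minimal usage sketch of the new layer. Everything in it besides loss_mean_squared_per_pixel itself (the regression_net name, the 1x1 convolution architecture, and the toy identity-mapping data) is illustrative only and not part of this commit; it assumes the standard dlib dnn API.

#include <dlib/dnn.h>
#include <vector>
using namespace dlib;

// A per-pixel regression net: a single 1x1 convolution producing one channel,
// trained against matrix<float> targets with the same dimensions as the output.
using regression_net = loss_mean_squared_per_pixel<
                           con<1,1,1,1,1,
                           input<matrix<float>>>>;

int main()
{
    // Toy dataset: learn the identity mapping on random 8x8 images.
    std::vector<matrix<float>> samples(20);
    for (auto& s : samples)
        s = matrix_cast<float>(randm(8, 8));

    regression_net net;
    dnn_trainer<regression_net> trainer(net, sgd(0, 0.9));
    trainer.set_learning_rate(0.01);
    trainer.set_max_num_epochs(100);
    trainer.train(samples, samples);    // truth matrices must match the output size

    const auto outputs = net(samples);  // outputs[i] is the predicted matrix<float>
    return 0;
}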
@@ -1995,6 +1995,80 @@ namespace
    }

// ----------------------------------------------------------------------------------------
    void test_simple_autoencoder()
    {
        print_spinner();

        const int output_width = 7;
        const int output_height = 7;
        const int num_samples = 100;

        ::std::vector<matrix<float>> x(num_samples);

        // Four simple target patterns: horizontal and vertical gradients and their inverses.
        matrix<float> tmp(output_height, output_width);
        for (int i = 0; i < num_samples; ++i)
        {
            const int model = i % 4;
            for (int r = 0; r < output_height; ++r)
                for (int c = 0; c < output_width; ++c)
                    switch (model) {
                    case 0: tmp(r, c) = static_cast<float>(r) / output_height; break;
                    case 1: tmp(r, c) = static_cast<float>(c) / output_width; break;
                    case 2: tmp(r, c) = 1.0f - static_cast<float>(r) / output_height; break;
                    case 3: tmp(r, c) = 1.0f - static_cast<float>(c) / output_width; break;
                    default: DLIB_TEST_MSG(false, "Invalid model: " << model << " (should be between 0 and 3)");
                    }
            x[i] = tmp;
        }

        // Autoencoder: a 7x7 stride-2 convolution down to a 1x1x4 code, followed by a
        // 7x7 stride-2 transposed convolution back up to a 7x7x1 output.
        using net_type = loss_mean_squared_per_pixel<
                            cont<1,output_height,output_width,2,2,
                            relu<con<4,output_height,output_width,2,2,
                            input<matrix<float>>>>>>;
        net_type net;

        const auto autoencoder_error = [&x, &net, &output_height, &output_width]()
        {
            const auto y = net(x);
            double error = 0.0;
            for (size_t i = 0; i < x.size(); ++i)
                for (int r = 0; r < output_height; ++r)
                    for (int c = 0; c < output_width; ++c)
                        error += fabs(y[i](r, c) - x[i](r, c));
            return error / (x.size() * output_height * output_width);
        };

        // The autoencoder can't be very good before it's been trained: the probability of
        // the reconstruction error being small should be very low, and in fact the error
        // ought to be much higher than 0.01. However, since the initialization is random,
        // setting the threshold too high could make the test fail when other, unrelated
        // tests are added to the sequence.
        const double error_before = autoencoder_error();
        DLIB_TEST_MSG(error_before > 0.01, "Autoencoder error before training = " << error_before);

        // Make sure there's an information bottleneck, as intended.
        const auto& output2 = dlib::layer<2>(net).get_output();
        DLIB_TEST(output2.nr() == 1);
        DLIB_TEST(output2.nc() == 1);
        DLIB_TEST(output2.k() == 4);

        sgd defsolver(0, 0.9);
        dnn_trainer<net_type> trainer(net, defsolver);
        trainer.set_learning_rate(0.01);
        trainer.set_max_num_epochs(1000);
        trainer.train(x, x);

        // By now, the network should have learned to reconstruct the training patterns.
        const double error_after = autoencoder_error();
        DLIB_TEST_MSG(error_after < 1e-6, "Autoencoder error after training = " << error_after);
    }
// ----------------------------------------------------------------------------------------

    void test_loss_multiclass_per_pixel_learned_params_on_trivial_single_pixel_task()

@@ -2574,6 +2648,7 @@ namespace
        test_concat();
        test_simple_linear_regression();
        test_multioutput_linear_regression();
+        test_simple_autoencoder();
        test_loss_multiclass_per_pixel_learned_params_on_trivial_single_pixel_task();
        test_loss_multiclass_per_pixel_activations_on_trivial_single_pixel_task();
        test_loss_multiclass_per_pixel_outputs_on_trivial_task();
...