Commit aa537dcc authored by Davis King

Yet more robustness tweaks for test_layer().

parent d08cd587
...@@ -199,7 +199,7 @@ namespace dlib ...@@ -199,7 +199,7 @@ namespace dlib
auto call_layer_forward( auto call_layer_forward(
layer_type& layer, layer_type& layer,
const SUBNET& sub, const SUBNET& sub,
tensor& data_output tensor& /*data_output*/
) -> decltype(layer.forward(sub,rt())) ) -> decltype(layer.forward(sub,rt()))
{ {
// This overload of call_layer_forward() is here because this template // This overload of call_layer_forward() is here because this template
...@@ -1895,8 +1895,10 @@ namespace dlib ...@@ -1895,8 +1895,10 @@ namespace dlib
const float base_eps = 0.01; const float base_eps = 0.01;
using namespace timpl; using namespace timpl;
// Do some setup // Do some setup
running_stats<double> rs_data, rs_params;
dlib::rand rnd; dlib::rand rnd;
for (int iter = 0; iter < 5; ++iter) std::ostringstream sout;
for (int iter = 0; iter < 10; ++iter)
{ {
test_layer_subnet subnetwork(rnd); test_layer_subnet subnetwork(rnd);
resizable_tensor output, out2, out3; resizable_tensor output, out2, out3;
...@@ -1911,7 +1913,6 @@ namespace dlib ...@@ -1911,7 +1913,6 @@ namespace dlib
input_grad.copy_size(output); input_grad.copy_size(output);
fill_with_gassuan_random_numbers(input_grad, rnd); fill_with_gassuan_random_numbers(input_grad, rnd);
std::ostringstream sout;
// The f() we are computing gradients of is this thing. Its value at the current // The f() we are computing gradients of is this thing. Its value at the current
// parameter and data values is: // parameter and data values is:
...@@ -2020,7 +2021,8 @@ namespace dlib ...@@ -2020,7 +2021,8 @@ namespace dlib
double output_derivative = params_grad.host()[i]; double output_derivative = params_grad.host()[i];
double relative_error = (reference_derivative - output_derivative)/(reference_derivative + 1e-100); double relative_error = (reference_derivative - output_derivative)/(reference_derivative + 1e-100);
double absolute_error = (reference_derivative - output_derivative); double absolute_error = (reference_derivative - output_derivative);
if (std::abs(relative_error) > 0.02 && std::abs(absolute_error) > 0.003) rs_params.add(std::abs(relative_error));
if (std::abs(relative_error) > 0.05 && std::abs(absolute_error) > 0.005)
{ {
using namespace std; using namespace std;
sout << "Gradient error in parameter #" << i <<". Relative error: "<< relative_error << endl; sout << "Gradient error in parameter #" << i <<". Relative error: "<< relative_error << endl;
...@@ -2028,7 +2030,6 @@ namespace dlib ...@@ -2028,7 +2030,6 @@ namespace dlib
sout << "output derivative: " << output_derivative << endl; sout << "output derivative: " << output_derivative << endl;
return layer_test_results(sout.str()); return layer_test_results(sout.str());
} }
} }
// ================================================================== // ==================================================================
...@@ -2053,7 +2054,8 @@ namespace dlib ...@@ -2053,7 +2054,8 @@ namespace dlib
output_derivative -= initial_gradient_input[i]; output_derivative -= initial_gradient_input[i];
double relative_error = (reference_derivative - output_derivative)/(reference_derivative + 1e-100); double relative_error = (reference_derivative - output_derivative)/(reference_derivative + 1e-100);
double absolute_error = (reference_derivative - output_derivative); double absolute_error = (reference_derivative - output_derivative);
if (std::abs(relative_error) > 0.02 && std::abs(absolute_error) > 0.003) rs_data.add(std::abs(relative_error));
if (std::abs(relative_error) > 0.05 && std::abs(absolute_error) > 0.005)
{ {
using namespace std; using namespace std;
sout << "Gradient error in data variable #" << i <<". Relative error: "<< relative_error << endl; sout << "Gradient error in data variable #" << i <<". Relative error: "<< relative_error << endl;
...@@ -2065,6 +2067,19 @@ namespace dlib ...@@ -2065,6 +2067,19 @@ namespace dlib
} // end for (int iter = 0; iter < 5; ++iter) } // end for (int iter = 0; iter < 5; ++iter)
if (rs_params.mean() > 0.003)
{
using namespace std;
sout << "Average parameter gradient error is somewhat large at: "<< rs_params.mean() << endl;
return layer_test_results(sout.str());
}
if (rs_data.mean() > 0.003)
{
using namespace std;
sout << "Average data gradient error is somewhat large at: "<< rs_data.mean() << endl;
return layer_test_results(sout.str());
}
return layer_test_results(); return layer_test_results();
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment