Commit 6137540b authored by Davis King's avatar Davis King

Changed test_regression_function() and cross_validate_regression_trainer() to

output 2 more statistics, which are the mean absolute error and the standard
deviation of the absolute error.  This means these functions now return 4D
rather than 2D vectors.

I also made test_regression_function() take a non-const reference to the
regression function so that DNN objects can be tested.
parent 5a0c09c7
...@@ -18,9 +18,9 @@ namespace dlib ...@@ -18,9 +18,9 @@ namespace dlib
typename sample_type, typename sample_type,
typename label_type typename label_type
> >
matrix<double,1,2> matrix<double,1,4>
test_regression_function ( test_regression_function (
const reg_funct_type& reg_funct, reg_funct_type& reg_funct,
const std::vector<sample_type>& x_test, const std::vector<sample_type>& x_test,
const std::vector<label_type>& y_test const std::vector<label_type>& y_test
) )
...@@ -33,7 +33,7 @@ namespace dlib ...@@ -33,7 +33,7 @@ namespace dlib
<< "\n\t is_learning_problem(x_test,y_test): " << "\n\t is_learning_problem(x_test,y_test): "
<< is_learning_problem(x_test,y_test)); << is_learning_problem(x_test,y_test));
running_stats<double> rs; running_stats<double> rs, rs_mae;
running_scalar_covariance<double> rc; running_scalar_covariance<double> rc;
for (unsigned long i = 0; i < x_test.size(); ++i) for (unsigned long i = 0; i < x_test.size(); ++i)
...@@ -42,12 +42,13 @@ namespace dlib ...@@ -42,12 +42,13 @@ namespace dlib
const double output = reg_funct(x_test[i]); const double output = reg_funct(x_test[i]);
const double temp = output - y_test[i]; const double temp = output - y_test[i];
rs_mae.add(std::abs(temp));
rs.add(temp*temp); rs.add(temp*temp);
rc.add(output, y_test[i]); rc.add(output, y_test[i]);
} }
matrix<double,1,2> result; matrix<double,1,4> result;
result = rs.mean(), std::pow(rc.correlation(),2); result = rs.mean(), std::pow(rc.correlation(),2), rs_mae.mean(), rs_mae.stddev();
return result; return result;
} }
...@@ -58,7 +59,7 @@ namespace dlib ...@@ -58,7 +59,7 @@ namespace dlib
typename sample_type, typename sample_type,
typename label_type typename label_type
> >
matrix<double,1,2> matrix<double,1,4>
cross_validate_regression_trainer ( cross_validate_regression_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const std::vector<sample_type>& x, const std::vector<sample_type>& x,
...@@ -82,7 +83,7 @@ namespace dlib ...@@ -82,7 +83,7 @@ namespace dlib
const long num_in_test = x.size()/folds; const long num_in_test = x.size()/folds;
const long num_in_train = x.size() - num_in_test; const long num_in_train = x.size() - num_in_test;
running_stats<double> rs; running_stats<double> rs, rs_mae;
running_scalar_covariance<double> rc; running_scalar_covariance<double> rc;
std::vector<sample_type> x_test, x_train; std::vector<sample_type> x_test, x_train;
...@@ -128,6 +129,7 @@ namespace dlib ...@@ -128,6 +129,7 @@ namespace dlib
const double output = df(x_test[j]); const double output = df(x_test[j]);
const double temp = output - y_test[j]; const double temp = output - y_test[j];
rs_mae.add(std::abs(temp));
rs.add(temp*temp); rs.add(temp*temp);
rc.add(output, y_test[j]); rc.add(output, y_test[j]);
} }
...@@ -139,8 +141,8 @@ namespace dlib ...@@ -139,8 +141,8 @@ namespace dlib
} // for (long i = 0; i < folds; ++i) } // for (long i = 0; i < folds; ++i)
matrix<double,1,2> result; matrix<double,1,4> result;
result = rs.mean(), std::pow(rc.correlation(),2); result = rs.mean(), std::pow(rc.correlation(),2), rs_mae.mean(), rs_mae.stddev();
return result; return result;
} }
......
...@@ -16,9 +16,9 @@ namespace dlib ...@@ -16,9 +16,9 @@ namespace dlib
typename sample_type, typename sample_type,
typename label_type typename label_type
> >
matrix<double,1,2> matrix<double,1,4>
test_regression_function ( test_regression_function (
const reg_funct_type& reg_funct, reg_funct_type& reg_funct,
const std::vector<sample_type>& x_test, const std::vector<sample_type>& x_test,
const std::vector<label_type>& y_test const std::vector<label_type>& y_test
); );
...@@ -35,6 +35,9 @@ namespace dlib ...@@ -35,6 +35,9 @@ namespace dlib
- M(1) == the R-squared value (i.e. the squared correlation between - M(1) == the R-squared value (i.e. the squared correlation between
reg_funct(x_test[i]) and y_test[i]). This is a number between 0 reg_funct(x_test[i]) and y_test[i]). This is a number between 0
and 1. and 1.
- M(2) == the mean absolute error.
This is given by: sum over i: abs(reg_funct(x_test[i]) - y_test[i])
- M(3) == the standard deviation of the absolute error.
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -44,7 +47,7 @@ namespace dlib ...@@ -44,7 +47,7 @@ namespace dlib
typename sample_type, typename sample_type,
typename label_type typename label_type
> >
matrix<double,1,2> matrix<double,1,4>
cross_validate_regression_trainer ( cross_validate_regression_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const std::vector<sample_type>& x, const std::vector<sample_type>& x,
...@@ -66,6 +69,9 @@ namespace dlib ...@@ -66,6 +69,9 @@ namespace dlib
- M(1) == the R-squared value (i.e. the squared correlation between - M(1) == the R-squared value (i.e. the squared correlation between
a predicted y value and its true value). This is a number between a predicted y value and its true value). This is a number between
0 and 1. 0 and 1.
- M(2) == the mean absolute error.
This is given by: sum over i: abs(reg_funct(x_test[i]) - y_test[i])
- M(3) == the standard deviation of the absolute error.
!*/ !*/
} }
......
...@@ -247,7 +247,7 @@ namespace ...@@ -247,7 +247,7 @@ namespace
randomize_samples(samples, labels); randomize_samples(samples, labels);
dlog << LINFO << "KRR MSE and R-squared: "<< cross_validate_regression_trainer(krr_test, samples, labels, 6); dlog << LINFO << "KRR MSE and R-squared: "<< cross_validate_regression_trainer(krr_test, samples, labels, 6);
dlog << LINFO << "SVR MSE and R-squared: "<< cross_validate_regression_trainer(svr_test, samples, labels, 6); dlog << LINFO << "SVR MSE and R-squared: "<< cross_validate_regression_trainer(svr_test, samples, labels, 6);
matrix<double,1,2> cv = cross_validate_regression_trainer(krr_test, samples, labels, 6); matrix<double,1,4> cv = cross_validate_regression_trainer(krr_test, samples, labels, 6);
DLIB_TEST(cv(0) < 1e-4); DLIB_TEST(cv(0) < 1e-4);
DLIB_TEST(cv(1) > 0.99); DLIB_TEST(cv(1) > 0.99);
cv = cross_validate_regression_trainer(svr_test, samples, labels, 6); cv = cross_validate_regression_trainer(svr_test, samples, labels, 6);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment