Commit cafa17c8 authored by Davis King

Changed the test_regression_function() and cross_validate_regression_trainer()

routines so they return both the MSE and R-squared values rather than just the
MSE.
parent e2e342aa
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <vector> #include <vector>
#include "../matrix.h" #include "../matrix.h"
#include "../statistics.h" #include "../statistics.h"
#include "cross_validate_regression_trainer_abstract.h"
namespace dlib namespace dlib
{ {
...@@ -17,7 +18,7 @@ namespace dlib ...@@ -17,7 +18,7 @@ namespace dlib
typename sample_type, typename sample_type,
typename label_type typename label_type
> >
label_type matrix<double,1,2>
test_regression_function ( test_regression_function (
const reg_funct_type& reg_funct, const reg_funct_type& reg_funct,
const std::vector<sample_type>& x_test, const std::vector<sample_type>& x_test,
...@@ -33,17 +34,22 @@ namespace dlib ...@@ -33,17 +34,22 @@ namespace dlib
<< "\n\t is_learning_problem(x_test,y_test): " << "\n\t is_learning_problem(x_test,y_test): "
<< is_learning_problem(x_test,y_test)); << is_learning_problem(x_test,y_test));
running_stats<label_type> rs; running_stats<double> rs;
running_scalar_covariance<double> rc;
for (unsigned long i = 0; i < x_test.size(); ++i) for (unsigned long i = 0; i < x_test.size(); ++i)
{ {
// compute error // compute error
label_type temp = reg_funct(x_test[i]) - y_test[i]; const double output = reg_funct(x_test[i]);
const double temp = output - y_test[i];
rs.add(temp*temp); rs.add(temp*temp);
rc.add(output, y_test[i]);
} }
return rs.mean(); matrix<double,1,2> result;
result = rs.mean(), std::pow(rc.correlation(),2);
return result;
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -53,7 +59,7 @@ namespace dlib ...@@ -53,7 +59,7 @@ namespace dlib
typename sample_type, typename sample_type,
typename label_type typename label_type
> >
label_type matrix<double,1,2>
cross_validate_regression_trainer ( cross_validate_regression_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const std::vector<sample_type>& x, const std::vector<sample_type>& x,
...@@ -78,11 +84,12 @@ namespace dlib ...@@ -78,11 +84,12 @@ namespace dlib
const long num_in_test = x.size()/folds; const long num_in_test = x.size()/folds;
const long num_in_train = x.size() - num_in_test; const long num_in_train = x.size() - num_in_test;
running_stats<double> rs;
running_scalar_covariance<double> rc;
std::vector<sample_type> x_test, x_train; std::vector<sample_type> x_test, x_train;
std::vector<label_type> y_test, y_train; std::vector<label_type> y_test, y_train;
running_stats<label_type> rs;
long next_test_idx = 0; long next_test_idx = 0;
...@@ -114,8 +121,18 @@ namespace dlib ...@@ -114,8 +121,18 @@ namespace dlib
try try
{ {
const typename trainer_type::trained_function_type& df = trainer.train(x_train,y_train);
// do the training and testing // do the training and testing
rs.add(test_regression_function(trainer.train(x_train,y_train),x_test,y_test)); for (long j = 0; j < x_test.size(); ++j)
{
// compute error
const double output = df(x_test[j]);
const double temp = output - y_test[j];
rs.add(temp*temp);
rc.add(output, y_test[j]);
}
} }
catch (invalid_nu_error&) catch (invalid_nu_error&)
{ {
...@@ -124,7 +141,9 @@ namespace dlib ...@@ -124,7 +141,9 @@ namespace dlib
} // for (long i = 0; i < folds; ++i) } // for (long i = 0; i < folds; ++i)
return rs.mean(); matrix<double,1,2> result;
result = rs.mean(), std::pow(rc.correlation(),2);
return result;
} }
} }
......
...@@ -16,7 +16,7 @@ namespace dlib ...@@ -16,7 +16,7 @@ namespace dlib
typename sample_type, typename sample_type,
typename label_type typename label_type
> >
label_type matrix<double,1,2>
test_regression_function ( test_regression_function (
const reg_funct_type& reg_funct, const reg_funct_type& reg_funct,
const std::vector<sample_type>& x_test, const std::vector<sample_type>& x_test,
...@@ -29,9 +29,12 @@ namespace dlib ...@@ -29,9 +29,12 @@ namespace dlib
(e.g. a decision_function created by the svr_trainer ) (e.g. a decision_function created by the svr_trainer )
ensures ensures
- Tests reg_funct against the given samples in x_test and target values in - Tests reg_funct against the given samples in x_test and target values in
y_test and returns the mean squared error. Specifically, the MSE is given y_test and returns a matrix M summarizing the results. Specifically:
by: - M(0) == the mean squared error.
sum over i: pow(reg_funct(x_test[i]) - y_test[i], 2.0) The MSE is given by: sum over i: pow(reg_funct(x_test[i]) - y_test[i], 2.0)
- M(1) == the R-squared value (i.e. the squared correlation between
reg_funct(x_test[i]) and y_test[i]). This is a number between 0
and 1.
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -41,7 +44,7 @@ namespace dlib ...@@ -41,7 +44,7 @@ namespace dlib
typename sample_type, typename sample_type,
typename label_type typename label_type
> >
label_type matrix<double,1,2>
cross_validate_regression_trainer ( cross_validate_regression_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const std::vector<sample_type>& x, const std::vector<sample_type>& x,
...@@ -54,11 +57,15 @@ namespace dlib ...@@ -54,11 +57,15 @@ namespace dlib
- 1 < folds <= x.size() - 1 < folds <= x.size()
- trainer_type == some kind of regression trainer object (e.g. svr_trainer) - trainer_type == some kind of regression trainer object (e.g. svr_trainer)
ensures ensures
- performs k-fold cross validation by using the given trainer to solve the - Performs k-fold cross validation by using the given trainer to solve a
given regression problem for the given number of folds. Each fold is tested using regression problem for the given number of folds. Each fold is tested using
the output of the trainer and the mean squared error is computed and returned. the output of the trainer. A matrix M summarizing the results is returned.
- The total MSE is computed by running test_binary_decision_function() Specifically:
on each fold and averaging its output. - M(0) == the mean squared error.
The MSE is given by: sum over i: pow(reg_funct(x[i]) - y[i], 2.0)
- M(1) == the R-squared value (i.e. the squared correlation between
a predicted y value and its true value). This is a number between
0 and 1.
!*/ !*/
} }
......
...@@ -241,10 +241,14 @@ namespace ...@@ -241,10 +241,14 @@ namespace
randomize_samples(samples, labels); randomize_samples(samples, labels);
dlog << LINFO << "KRR MSE: "<< cross_validate_regression_trainer(krr_test, samples, labels, 6); dlog << LINFO << "KRR MSE and R-squared: "<< cross_validate_regression_trainer(krr_test, samples, labels, 6);
dlog << LINFO << "SVR MSE: "<< cross_validate_regression_trainer(svr_test, samples, labels, 6); dlog << LINFO << "SVR MSE and R-squared: "<< cross_validate_regression_trainer(svr_test, samples, labels, 6);
DLIB_TEST(cross_validate_regression_trainer(krr_test, samples, labels, 6) < 1e-4); matrix<double,1,2> cv = cross_validate_regression_trainer(krr_test, samples, labels, 6);
DLIB_TEST(cross_validate_regression_trainer(svr_test, samples, labels, 6) < 1e-4); DLIB_TEST(cv(0) < 1e-4);
DLIB_TEST(cv(1) > 0.99);
cv = cross_validate_regression_trainer(svr_test, samples, labels, 6);
DLIB_TEST(cv(0) < 1e-4);
DLIB_TEST(cv(1) > 0.99);
dlog << LINFO << " end test_regression()"; dlog << LINFO << " end test_regression()";
} }
......
...@@ -84,13 +84,13 @@ int main() ...@@ -84,13 +84,13 @@ int main()
// The first column is the true value of the sinc function and the second // The first column is the true value of the sinc function and the second
// column is the output from the SVR estimate. // column is the output from the SVR estimate.
// We can also do 5-fold cross-validation and find the mean squared error. Note that // We can also do 5-fold cross-validation and find the mean squared error and R-squared
// we need to randomly shuffle the samples first. See the svm_ex.cpp for a discussion of // values. Note that we need to randomly shuffle the samples first. See the svm_ex.cpp
// why this is important. // for a discussion of why this is important.
randomize_samples(samples, targets); randomize_samples(samples, targets);
cout << "MSE: "<< cross_validate_regression_trainer(trainer, samples, targets, 5) << endl; cout << "MSE and R-Squared: "<< cross_validate_regression_trainer(trainer, samples, targets, 5) << endl;
// The output is: // The output is:
// MSE: 1.65984e-05 // MSE and R-Squared: 1.65984e-05 0.999901
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment