Commit cafa17c8 authored by Davis King

Changed the test_regression_function() and cross_validate_regression_trainer()

routines so they return both the MSE and R-squared values rather than just the
MSE.
parent e2e342aa
......@@ -6,6 +6,7 @@
#include <vector>
#include "../matrix.h"
#include "../statistics.h"
#include "cross_validate_regression_trainer_abstract.h"
namespace dlib
{
......@@ -17,7 +18,7 @@ namespace dlib
typename sample_type,
typename label_type
>
label_type
matrix<double,1,2>
test_regression_function (
const reg_funct_type& reg_funct,
const std::vector<sample_type>& x_test,
......@@ -33,17 +34,22 @@ namespace dlib
<< "\n\t is_learning_problem(x_test,y_test): "
<< is_learning_problem(x_test,y_test));
running_stats<label_type> rs;
running_stats<double> rs;
running_scalar_covariance<double> rc;
for (unsigned long i = 0; i < x_test.size(); ++i)
{
// compute error
label_type temp = reg_funct(x_test[i]) - y_test[i];
const double output = reg_funct(x_test[i]);
const double temp = output - y_test[i];
rs.add(temp*temp);
rc.add(output, y_test[i]);
}
return rs.mean();
matrix<double,1,2> result;
result = rs.mean(), std::pow(rc.correlation(),2);
return result;
}
// ----------------------------------------------------------------------------------------
......@@ -53,7 +59,7 @@ namespace dlib
typename sample_type,
typename label_type
>
label_type
matrix<double,1,2>
cross_validate_regression_trainer (
const trainer_type& trainer,
const std::vector<sample_type>& x,
......@@ -78,11 +84,12 @@ namespace dlib
const long num_in_test = x.size()/folds;
const long num_in_train = x.size() - num_in_test;
running_stats<double> rs;
running_scalar_covariance<double> rc;
std::vector<sample_type> x_test, x_train;
std::vector<label_type> y_test, y_train;
running_stats<label_type> rs;
long next_test_idx = 0;
......@@ -114,8 +121,18 @@ namespace dlib
try
{
const typename trainer_type::trained_function_type& df = trainer.train(x_train,y_train);
// do the training and testing
rs.add(test_regression_function(trainer.train(x_train,y_train),x_test,y_test));
for (long j = 0; j < x_test.size(); ++j)
{
// compute error
const double output = df(x_test[j]);
const double temp = output - y_test[j];
rs.add(temp*temp);
rc.add(output, y_test[j]);
}
}
catch (invalid_nu_error&)
{
......@@ -124,7 +141,9 @@ namespace dlib
} // for (long i = 0; i < folds; ++i)
return rs.mean();
matrix<double,1,2> result;
result = rs.mean(), std::pow(rc.correlation(),2);
return result;
}
}
......
......@@ -16,7 +16,7 @@ namespace dlib
typename sample_type,
typename label_type
>
label_type
matrix<double,1,2>
test_regression_function (
const reg_funct_type& reg_funct,
const std::vector<sample_type>& x_test,
......@@ -29,9 +29,12 @@ namespace dlib
(e.g. a decision_function created by the svr_trainer )
ensures
- Tests reg_funct against the given samples in x_test and target values in
y_test and returns the mean squared error. Specifically, the MSE is given
by:
sum over i: pow(reg_funct(x_test[i]) - y_test[i], 2.0)
y_test and returns a matrix M summarizing the results. Specifically:
- M(0) == the mean squared error.
The MSE is given by: sum over i: pow(reg_funct(x_test[i]) - y_test[i], 2.0)
- M(1) == the R-squared value (i.e. the squared correlation between
reg_funct(x_test[i]) and y_test[i]). This is a number between 0
and 1.
!*/
// ----------------------------------------------------------------------------------------
......@@ -41,7 +44,7 @@ namespace dlib
typename sample_type,
typename label_type
>
label_type
matrix<double,1,2>
cross_validate_regression_trainer (
const trainer_type& trainer,
const std::vector<sample_type>& x,
......@@ -54,11 +57,15 @@ namespace dlib
- 1 < folds <= x.size()
- trainer_type == some kind of regression trainer object (e.g. svr_trainer)
ensures
- performs k-fold cross validation by using the given trainer to solve the
given regression problem for the given number of folds. Each fold is tested using
the output of the trainer and the mean squared error is computed and returned.
- The total MSE is computed by running test_binary_decision_function()
on each fold and averaging its output.
- Performs k-fold cross validation by using the given trainer to solve a
regression problem for the given number of folds. Each fold is tested using
the output of the trainer. A matrix M summarizing the results is returned.
Specifically:
- M(0) == the mean squared error.
The MSE is given by: sum over i: pow(reg_funct(x[i]) - y[i], 2.0)
- M(1) == the R-squared value (i.e. the squared correlation between
a predicted y value and its true value). This is a number between
0 and 1.
!*/
}
......
......@@ -241,10 +241,14 @@ namespace
randomize_samples(samples, labels);
dlog << LINFO << "KRR MSE: "<< cross_validate_regression_trainer(krr_test, samples, labels, 6);
dlog << LINFO << "SVR MSE: "<< cross_validate_regression_trainer(svr_test, samples, labels, 6);
DLIB_TEST(cross_validate_regression_trainer(krr_test, samples, labels, 6) < 1e-4);
DLIB_TEST(cross_validate_regression_trainer(svr_test, samples, labels, 6) < 1e-4);
dlog << LINFO << "KRR MSE and R-squared: "<< cross_validate_regression_trainer(krr_test, samples, labels, 6);
dlog << LINFO << "SVR MSE and R-squared: "<< cross_validate_regression_trainer(svr_test, samples, labels, 6);
matrix<double,1,2> cv = cross_validate_regression_trainer(krr_test, samples, labels, 6);
DLIB_TEST(cv(0) < 1e-4);
DLIB_TEST(cv(1) > 0.99);
cv = cross_validate_regression_trainer(svr_test, samples, labels, 6);
DLIB_TEST(cv(0) < 1e-4);
DLIB_TEST(cv(1) > 0.99);
dlog << LINFO << " end test_regression()";
}
......
......@@ -84,13 +84,13 @@ int main()
// The first column is the true value of the sinc function and the second
// column is the output from the SVR estimate.
// We can also do 5-fold cross-validation and find the mean squared error. Note that
// we need to randomly shuffle the samples first. See the svm_ex.cpp for a discussion of
// why this is important.
// We can also do 5-fold cross-validation and find the mean squared error and R-squared
// values. Note that we need to randomly shuffle the samples first. See the svm_ex.cpp
// for a discussion of why this is important.
randomize_samples(samples, targets);
cout << "MSE: "<< cross_validate_regression_trainer(trainer, samples, targets, 5) << endl;
cout << "MSE and R-Squared: "<< cross_validate_regression_trainer(trainer, samples, targets, 5) << endl;
// The output is:
// MSE: 1.65984e-05
// MSE and R-Squared: 1.65984e-05 0.999901
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment