Commit 07e48859 authored by Davis King's avatar Davis King

I needed to make a few minor changes to make this code work with the new

version of dlib.
parent e3310d45
......@@ -194,14 +194,15 @@ krr_rbk_test (
gamma = gamma_range.get_next_value (gamma))
{
// LOO cross validation
double loo_error;
std::vector<double> loo_values;
if (parser.option("verbose")) {
trainer.set_search_lambdas(logspace(-9, 4, 100));
trainer.be_verbose();
}
trainer.set_kernel (kernel_type (gamma));
trainer.train (dense_samples, labels, loo_error);
trainer.train (dense_samples, labels, loo_values);
const double loo_error = mean_squared_error(loo_values, labels);
if (loo_error < best_loo) {
best_loo = loo_error;
best_gamma = gamma;
......@@ -237,9 +238,12 @@ krr_lin_test (
krr_trainer<kernel_type> trainer;
// LOO cross validation
double loo_error;
trainer.train(dense_samples, labels, loo_error);
std::vector<double> loo_values;
trainer.train(dense_samples, labels, loo_values);
const double loo_error = mean_squared_error(loo_values, labels);
const double rs = r_squared(loo_values, labels);
std::cout << "mean squared LOO error: " << loo_error << std::endl;
std::cout << "R-Squared LOO: " << rs << std::endl;
}
// ----------------------------------------------------------------------------------------
......@@ -343,11 +347,11 @@ svr_test (
gamma = gamma_range.get_next_value (gamma))
{
cout << "test with svr-C: " << svr_c << " gamma: "<< gamma << flush;
double cv_error;
matrix<double,1,2> cv;
trainer.set_kernel (kernel_type (gamma));
cv_error = cross_validate_regression_trainer (trainer,
dense_samples, labels, 10);
cout << " 10-fold-MSE: "<< cv_error << endl;
cv = cross_validate_regression_trainer (trainer, dense_samples, labels, 10);
cout << " 10-fold-MSE: "<< cv(0) << endl;
cout << " 10-fold-R-Squared: "<< cv(1) << endl;
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment