Commit 2ecf811d authored by Davis King's avatar Davis King

Fixed a minor bug in how the cross validation accuracy was being

computed.
parent 94d7305a
......@@ -114,13 +114,14 @@ namespace dlib
const long num_in_test = samples.size()/folds;
const long num_in_train = samples.size() - num_in_test;
running_stats<double> rs;
std::vector<sample_type> samples_test, samples_train;
std::vector<label_type> labels_test, labels_train;
long next_test_idx = 0;
double total_right = 0;
double total = 0;
for (long i = 0; i < folds; ++i)
......@@ -148,13 +149,28 @@ namespace dlib
}
rs.add(test_assignment_function(trainer.train(samples_train,labels_train),
samples_test,
labels_test));
const typename trainer_type::trained_function_type& df = trainer.train(samples_train,labels_train);
// check how good df is on the test data
for (unsigned long i = 0; i < samples_test.size(); ++i)
{
const std::vector<long>& out = df(samples_test[i]);
for (unsigned long j = 0; j < out.size(); ++j)
{
if (out[j] == labels_test[i][j])
++total_right;
++total;
}
}
} // for (long i = 0; i < folds; ++i)
return rs.mean();
if (total != 0)
return total_right/total;
else
return 1;
}
// ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment