Commit fd751829 authored by Davis King's avatar Davis King

Added a test to make sure the probabilistic() trainer adapter works right

and also that it works with the one_vs_all_trainer without issue.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404076
parent 4805117a
......@@ -171,6 +171,82 @@ namespace
DLIB_TEST(res == ans);
}
template <typename label_type, typename scalar_type>
void run_probabilistic_test (
)
{
print_spinner();
typedef matrix<scalar_type,2,1> sample_type;
std::vector<sample_type> samples;
std::vector<label_type> labels;
// First, get our labeled set of training data
generate_data(samples, labels);
typedef one_vs_all_trainer<any_trainer<sample_type,scalar_type>,label_type > ova_trainer;
ova_trainer trainer;
typedef polynomial_kernel<sample_type> poly_kernel;
typedef radial_basis_kernel<sample_type> rbf_kernel;
// make the binary trainers and set some parameters
krr_trainer<rbf_kernel> rbf_trainer;
svm_nu_trainer<poly_kernel> poly_trainer;
poly_trainer.set_kernel(poly_kernel(0.1, 1, 2));
rbf_trainer.set_kernel(rbf_kernel(0.1));
trainer.set_trainer(probabilistic(rbf_trainer, 3));
trainer.set_trainer(probabilistic(poly_trainer, 3), 1);
randomize_samples(samples, labels);
matrix<scalar_type> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2);
print_spinner();
matrix<scalar_type> ans(3,3);
ans = 60, 0, 0,
0, 70, 0,
0, 0, 80;
DLIB_TEST_MSG(ans == res, "res: \n" << res);
one_vs_all_decision_function<ova_trainer> df = trainer.train(samples, labels);
DLIB_TEST(df.number_of_classes() == 3);
DLIB_TEST(df(samples[0]) == labels[0])
DLIB_TEST(df(samples[90]) == labels[90])
one_vs_all_decision_function<ova_trainer,
probabilistic_function<decision_function<poly_kernel> >, // This is the output of the poly_trainer
probabilistic_function<decision_function<rbf_kernel> > // This is the output of the rbf_trainer
> df2, df3;
df2 = df;
ofstream fout("df.dat", ios::binary);
serialize(df2, fout);
fout.close();
// load the function back in from disk and store it in df3.
ifstream fin("df.dat", ios::binary);
deserialize(df3, fin);
DLIB_TEST(df3(samples[0]) == labels[0])
DLIB_TEST(df3(samples[90]) == labels[90])
res = test_multiclass_decision_function(df3, samples, labels);
DLIB_TEST(res == ans);
}
void perform_test (
......@@ -187,6 +263,18 @@ namespace
dlog << LINFO << "run_test<int,float>()";
run_test<int,float>();
dlog << LINFO << "run_probabilistic_test<double,double>()";
run_probabilistic_test<double,double>();
dlog << LINFO << "run_probabilistic_test<int,double>()";
run_probabilistic_test<int,double>();
dlog << LINFO << "run_probabilistic_test<double,float>()";
run_probabilistic_test<double,float>();
dlog << LINFO << "run_probabilistic_test<int,float>()";
run_probabilistic_test<int,float>();
}
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment