Commit 1d777f2f authored by Davis King's avatar Davis King

turned the test_trainer function into the test_binary_decision_function

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402449
parent 1595f37a
...@@ -246,32 +246,27 @@ namespace dlib ...@@ -246,32 +246,27 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
typename trainer_type, typename dec_funct_type,
typename in_sample_vector_type, typename in_sample_vector_type,
typename in_scalar_vector_type typename in_scalar_vector_type
> >
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type> const matrix<typename dec_funct_type::scalar_type, 1, 2, typename dec_funct_type::mem_manager_type>
test_trainer_impl ( test_binary_decision_function_impl (
const trainer_type& trainer, const dec_funct_type& dec_funct,
const in_sample_vector_type& x_train,
const in_scalar_vector_type& y_train,
const in_sample_vector_type& x_test, const in_sample_vector_type& x_test,
const in_scalar_vector_type& y_test const in_scalar_vector_type& y_test
) )
{ {
typedef typename trainer_type::scalar_type scalar_type; typedef typename dec_funct_type::scalar_type scalar_type;
typedef typename trainer_type::sample_type sample_type; typedef typename dec_funct_type::sample_type sample_type;
typedef typename trainer_type::mem_manager_type mem_manager_type; typedef typename dec_funct_type::mem_manager_type mem_manager_type;
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type; typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type; typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT(is_binary_classification_problem(x_train,y_train) == true && DLIB_ASSERT( is_binary_classification_problem(x_test,y_test) == true,
is_binary_classification_problem(x_test,y_test) == true, "\tmatrix test_binary_decision_function()"
"\tmatrix test_trainer()"
<< "\n\t invalid inputs were given to this function" << "\n\t invalid inputs were given to this function"
<< "\n\t is_binary_classification_problem(x_train,y_train): "
<< ((is_binary_classification_problem(x_train,y_train))? "true":"false")
<< "\n\t is_binary_classification_problem(x_test,y_test): " << "\n\t is_binary_classification_problem(x_test,y_test): "
<< ((is_binary_classification_problem(x_test,y_test))? "true":"false")); << ((is_binary_classification_problem(x_test,y_test))? "true":"false"));
...@@ -284,11 +279,6 @@ namespace dlib ...@@ -284,11 +279,6 @@ namespace dlib
long num_pos_correct = 0; long num_pos_correct = 0;
long num_neg_correct = 0; long num_neg_correct = 0;
typename trainer_type::trained_function_type d;
// do the training
d = trainer.train(x_train,y_train);
// now test this trained object // now test this trained object
for (long i = 0; i < x_test.nr(); ++i) for (long i = 0; i < x_test.nr(); ++i)
...@@ -297,18 +287,18 @@ namespace dlib ...@@ -297,18 +287,18 @@ namespace dlib
if (y_test(i) == +1.0) if (y_test(i) == +1.0)
{ {
++num_pos; ++num_pos;
if (d(x_test(i)) >= 0) if (dec_funct(x_test(i)) >= 0)
++num_pos_correct; ++num_pos_correct;
} }
else if (y_test(i) == -1.0) else if (y_test(i) == -1.0)
{ {
++num_neg; ++num_neg;
if (d(x_test(i)) < 0) if (dec_funct(x_test(i)) < 0)
++num_neg_correct; ++num_neg_correct;
} }
else else
{ {
throw dlib::error("invalid input labels to the test_trainer() function"); throw dlib::error("invalid input labels to the test_binary_decision_function() function");
} }
} }
...@@ -320,22 +310,18 @@ namespace dlib ...@@ -320,22 +310,18 @@ namespace dlib
} }
template < template <
typename trainer_type, typename dec_funct_type,
typename in_sample_vector_type, typename in_sample_vector_type,
typename in_scalar_vector_type typename in_scalar_vector_type
> >
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type> const matrix<typename dec_funct_type::scalar_type, 1, 2, typename dec_funct_type::mem_manager_type>
test_trainer ( test_binary_decision_function (
const trainer_type& trainer, const dec_funct_type& dec_funct,
const in_sample_vector_type& x_train,
const in_scalar_vector_type& y_train,
const in_sample_vector_type& x_test, const in_sample_vector_type& x_test,
const in_scalar_vector_type& y_test const in_scalar_vector_type& y_test
) )
{ {
return test_trainer_impl(trainer, return test_binary_decision_function_impl(dec_funct,
vector_to_matrix(x_train),
vector_to_matrix(y_train),
vector_to_matrix(x_test), vector_to_matrix(x_test),
vector_to_matrix(y_test)); vector_to_matrix(y_test));
} }
...@@ -462,7 +448,8 @@ namespace dlib ...@@ -462,7 +448,8 @@ namespace dlib
train_neg_idx = (train_neg_idx+1)%x.nr(); train_neg_idx = (train_neg_idx+1)%x.nr();
} }
res += test_trainer(trainer,x_train,y_train,x_test,y_test); // do the training and testing
res += test_binary_decision_function(trainer.train(x_train,y_train),x_test,y_test);
} // for (long i = 0; i < folds; ++i) } // for (long i = 0; i < folds; ++i)
......
...@@ -338,33 +338,28 @@ namespace dlib ...@@ -338,33 +338,28 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
typename trainer_type, typename dec_funct_type,
typename in_sample_vector_type, typename in_sample_vector_type,
typename in_scalar_vector_type typename in_scalar_vector_type
> >
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type> const matrix<typename dec_funct_type::scalar_type, 1, 2, typename dec_funct_type::mem_manager_type>
test_trainer ( test_binary_decision_function (
const trainer_type& trainer, const dec_funct_type& trainer,
const in_sample_vector_type& x_train,
const in_scalar_vector_type& y_train,
const in_sample_vector_type& x_test, const in_sample_vector_type& x_test,
const in_scalar_vector_type& y_test const in_scalar_vector_type& y_test
); );
/*! /*!
requires requires
- is_binary_classification_problem(x_test,y_test) == true - is_binary_classification_problem(x_test,y_test) == true
- is_binary_classification_problem(x_train,y_train) == true - dec_funct_type == some kind of decision function object (e.g. decision_function)
- trainer_type == some kind of trainer object (e.g. svm_nu_trainer)
ensures ensures
- trains a single decision function by calling trainer.train(x_train,y_train) - tests the given decision function by calling on the x_test and y_test samples.
and tests the decision function on the x_test and y_test samples. - The test accuracy is returned in a column vector, let us call it R. Both
- The accuracy is returned in a column vector, let us call it R. Both
quantities in R are numbers between 0 and 1 which represent the fraction quantities in R are numbers between 0 and 1 which represent the fraction
of examples correctly classified. R(0) is the fraction of +1 examples of examples correctly classified. R(0) is the fraction of +1 examples
correctly classified and R(1) is the fraction of -1 examples correctly correctly classified and R(1) is the fraction of -1 examples correctly
classified. classified.
throws throws
- any exceptions thrown by trainer.train()
- std::bad_alloc - std::bad_alloc
!*/ !*/
......
...@@ -89,7 +89,7 @@ namespace dlib ...@@ -89,7 +89,7 @@ namespace dlib
matrix<scalar_type, 1, 2, mem_manager_type> temp_res; matrix<scalar_type, 1, 2, mem_manager_type> temp_res;
while (job_pipe.dequeue(j)) while (job_pipe.dequeue(j))
{ {
temp_res = test_trainer(j.trainer, j.x_train, j.y_train, j.x_test, j.y_test); temp_res = test_binary_decision_function(j.trainer.train(j.x_train, j.y_train), j.x_test, j.y_test);
res_pipe.enqueue(temp_res); res_pipe.enqueue(temp_res);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment