Commit fedf69e1 authored by Davis King's avatar Davis King

added the test_trainer and cross_validate_trainer_threaded functions

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402445
parent 7f9e7a53
......@@ -243,6 +243,103 @@ namespace dlib
};
// ----------------------------------------------------------------------------------------
template <
    typename trainer_type,
    typename in_sample_vector_type,
    typename in_scalar_vector_type
    >
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type>
test_trainer_impl (
    const trainer_type& trainer,
    const in_sample_vector_type& x_train,
    const in_scalar_vector_type& y_train,
    const in_sample_vector_type& x_test,
    const in_scalar_vector_type& y_test
)
/*!
    Trains a single decision function on (x_train,y_train) and measures its
    per-class accuracy on (x_test,y_test).  Returns a 1x2 matrix R where
    R(0) is the fraction of +1 test examples classified correctly and R(1)
    is the fraction of -1 test examples classified correctly.
!*/
{
    typedef typename trainer_type::scalar_type scalar_type;
    typedef typename trainer_type::mem_manager_type mem_manager_type;

    // make sure requires clause is not broken
    DLIB_ASSERT(is_binary_classification_problem(x_train,y_train) == true &&
                is_binary_classification_problem(x_test,y_test) == true,
        "\tmatrix test_trainer()"
        << "\n\t invalid inputs were given to this function"
        << "\n\t is_binary_classification_problem(x_train,y_train): " 
        << ((is_binary_classification_problem(x_train,y_train))? "true":"false")
        << "\n\t is_binary_classification_problem(x_test,y_test): " 
        << ((is_binary_classification_problem(x_test,y_test))? "true":"false"));

    // do the training
    typename trainer_type::trained_function_type d = trainer.train(x_train,y_train);

    // tally the per-class results on the test split
    long num_pos = 0;           // +1 examples seen
    long num_neg = 0;           // -1 examples seen
    long num_pos_correct = 0;   // +1 examples with a non-negative output
    long num_neg_correct = 0;   // -1 examples with a negative output

    for (long i = 0; i < x_test.nr(); ++i)
    {
        // the decision threshold is 0: outputs >= 0 are counted as class +1
        if (y_test(i) == +1.0)
        {
            ++num_pos;
            if (d(x_test(i)) >= 0)
                ++num_pos_correct;
        }
        else if (y_test(i) == -1.0)
        {
            ++num_neg;
            if (d(x_test(i)) < 0)
                ++num_neg_correct;
        }
        else
        {
            // labels other than +1/-1 are not valid for a binary problem
            throw dlib::error("invalid input labels to the test_trainer() function");
        }
    }

    matrix<scalar_type, 1, 2, mem_manager_type> res;
    res(0) = (scalar_type)num_pos_correct/(scalar_type)(num_pos);
    res(1) = (scalar_type)num_neg_correct/(scalar_type)(num_neg);
    return res;
}
template <
    typename trainer_type,
    typename in_sample_vector_type,
    typename in_scalar_vector_type
    >
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type>
test_trainer (
    const trainer_type& trainer,
    const in_sample_vector_type& x_train,
    const in_scalar_vector_type& y_train,
    const in_sample_vector_type& x_test,
    const in_scalar_vector_type& y_test
)
/*!
    Public entry point for test_trainer().  Normalizes whatever vector
    container types the caller passed into matrix expressions and then
    forwards everything on to test_trainer_impl() which does the real work.
!*/
{
    return test_trainer_impl(
        trainer,
        vector_to_matrix(x_train), vector_to_matrix(y_train),
        vector_to_matrix(x_test),  vector_to_matrix(y_test));
}
// ----------------------------------------------------------------------------------------
template <
......@@ -270,9 +367,6 @@ namespace dlib
"\tmatrix cross_validate_trainer()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.nr(): " << x.nr()
<< "\n\t y.nr(): " << y.nr()
<< "\n\t x.nc(): " << x.nc()
<< "\n\t y.nc(): " << y.nc()
<< "\n\t folds: " << folds
<< "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false")
);
......@@ -295,8 +389,6 @@ namespace dlib
const long num_neg_test_samples = num_neg/folds;
const long num_neg_train_samples = num_neg - num_neg_test_samples;
long num_pos_correct = 0;
long num_neg_correct = 0;
typename trainer_type::trained_function_type d;
sample_vector_type x_test, x_train;
......@@ -309,6 +401,9 @@ namespace dlib
long pos_idx = 0;
long neg_idx = 0;
matrix<scalar_type, 1, 2, mem_manager_type> res;
set_all_elements(res,0);
for (long i = 0; i < folds; ++i)
{
long cur = 0;
......@@ -367,35 +462,11 @@ namespace dlib
train_neg_idx = (train_neg_idx+1)%x.nr();
}
// do the training
d = trainer.train(x_train,y_train);
// now test this fold
for (long i = 0; i < x_test.nr(); ++i)
{
// if this is a positive example
if (y_test(i) == +1.0)
{
if (d(x_test(i)) >= 0)
++num_pos_correct;
}
else if (y_test(i) == -1.0)
{
if (d(x_test(i)) < 0)
++num_neg_correct;
}
else
{
throw dlib::error("invalid input labels to the cross_validate_trainer() function");
}
}
res += test_trainer(trainer,x_train,y_train,x_test,y_test);
} // for (long i = 0; i < folds; ++i)
matrix<scalar_type, 1, 2, mem_manager_type> res;
res(0) = (scalar_type)num_pos_correct/(scalar_type)(num_pos_test_samples*folds);
res(1) = (scalar_type)num_neg_correct/(scalar_type)(num_neg_test_samples*folds);
return res;
return res/(scalar_type)folds;
}
template <
......
......@@ -335,6 +335,39 @@ namespace dlib
- std::bad_alloc
!*/
// ----------------------------------------------------------------------------------------
template <
typename trainer_type,
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type>
test_trainer (
const trainer_type& trainer,
const in_sample_vector_type& x_train,
const in_scalar_vector_type& y_train,
const in_sample_vector_type& x_test,
const in_scalar_vector_type& y_test
);
/*!
requires
- is_binary_classification_problem(x_test,y_test) == true
- is_binary_classification_problem(x_train,y_train) == true
- trainer_type == some kind of trainer object (e.g. svm_nu_trainer)
ensures
- trains a single decision function by calling trainer.train(x_train,y_train)
and tests the decision function on the x_test and y_test samples.
- The accuracy is returned in a row vector, let us call it R. Both
quantities in R are numbers between 0 and 1 which represent the fraction
of examples correctly classified. R(0) is the fraction of +1 examples
correctly classified and R(1) is the fraction of -1 examples correctly
classified.
throws
- any exceptions thrown by trainer.train()
- std::bad_alloc
!*/
// ----------------------------------------------------------------------------------------
template <
......
// Copyright (C) 2008 Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_SVm_THREADED_
#define DLIB_SVm_THREADED_
#include "svm_threaded_abstract.h"
#include "svm.h"
#include <cmath>
#include <limits>
#include <sstream>
#include "../matrix.h"
#include "../algs.h"
#include "../serialize.h"
#include "function.h"
#include "kernel.h"
#include "../threads.h"
#include <vector>
#include "../smart_pointers.h"
#include "../pipe.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
namespace cvtti_helpers
{
template <typename trainer_type>
struct job
{
typedef typename trainer_type::scalar_type scalar_type;
typedef typename trainer_type::sample_type sample_type;
typedef typename trainer_type::mem_manager_type mem_manager_type;
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
trainer_type trainer;
sample_vector_type x_test, x_train;
scalar_vector_type y_test, y_train;
};
template <typename trainer_type>
void swap(
job<trainer_type>& a,
job<trainer_type>& b
)
{
exchange(a.trainer, b.trainer);
exchange(a.x_test, b.x_test);
exchange(a.y_test, b.y_test);
exchange(a.x_train, b.x_train);
exchange(a.y_train, b.y_train);
}
template <typename trainer_type>
class a_thread : multithreaded_object
{
public:
typedef typename trainer_type::scalar_type scalar_type;
typedef typename trainer_type::sample_type sample_type;
typedef typename trainer_type::mem_manager_type mem_manager_type;
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
explicit a_thread( long num_threads) : job_pipe(1), res_pipe(3)
{
for (long i = 0; i < num_threads; ++i)
{
register_thread(*this, &a_thread::thread);
}
start();
}
~a_thread()
{
// disable the job_pipe so that the threads will unblock and terminate
job_pipe.disable();
wait();
}
typename pipe<job<trainer_type> > ::kernel_1a job_pipe;
typename pipe<matrix<scalar_type, 1, 2, mem_manager_type> >::kernel_1a res_pipe;
private:
void thread()
{
job<trainer_type> j;
matrix<scalar_type, 1, 2, mem_manager_type> temp_res;
while (job_pipe.dequeue(j))
{
temp_res = test_trainer(j.trainer, j.x_train, j.y_train, j.x_test, j.y_test);
res_pipe.enqueue(temp_res);
}
}
};
}
template <
    typename trainer_type,
    typename in_sample_vector_type,
    typename in_scalar_vector_type
    >
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type>
cross_validate_trainer_threaded_impl (
    const trainer_type& trainer,
    const in_sample_vector_type& x,
    const in_scalar_vector_type& y,
    const long folds,
    const long num_threads
)
/*!
    requires
        - is_binary_classification_problem(x,y) == true
        - 1 < folds <= x.nr()
        - num_threads > 0
    ensures
        - performs folds-fold cross validation using num_threads worker
          threads.  Returns a 1x2 matrix R where R(0) is the average
          fraction of +1 examples classified correctly and R(1) is the
          average fraction of -1 examples classified correctly.
!*/
{
    using namespace dlib::cvtti_helpers;
    typedef typename trainer_type::scalar_type scalar_type;
    typedef typename trainer_type::mem_manager_type mem_manager_type;

    // make sure requires clause is not broken
    DLIB_ASSERT(is_binary_classification_problem(x,y) == true &&
                1 < folds && folds <= x.nr() &&
                num_threads > 0,
        "\tmatrix cross_validate_trainer_threaded()"
        << "\n\t invalid inputs were given to this function"
        << "\n\t x.nr(): " << x.nr() 
        << "\n\t folds: " << folds 
        << "\n\t num_threads: " << num_threads 
        << "\n\t is_binary_classification_problem(x,y): " << ((is_binary_classification_problem(x,y))? "true":"false")
        );

    // count the number of positive and negative examples
    long num_pos = 0;
    long num_neg = 0;
    for (long r = 0; r < y.nr(); ++r)
    {
        if (y(r) == +1.0)
            ++num_pos;
        else
            ++num_neg;
    }

    // figure out how many positive and negative examples we will have in each fold
    const long num_pos_test_samples = num_pos/folds; 
    const long num_pos_train_samples = num_pos - num_pos_test_samples; 
    const long num_neg_test_samples = num_neg/folds; 
    const long num_neg_train_samples = num_neg - num_neg_test_samples; 

    // rolling cursors into x/y used to deal the samples out to the folds
    long pos_idx = 0;
    long neg_idx = 0;

    job<trainer_type> j;
    a_thread<trainer_type> threads(num_threads);

    for (long i = 0; i < folds; ++i)
    {
        j.x_test.set_size (num_pos_test_samples + num_neg_test_samples);
        j.y_test.set_size (num_pos_test_samples + num_neg_test_samples);
        j.x_train.set_size(num_pos_train_samples + num_neg_train_samples);
        j.y_train.set_size(num_pos_train_samples + num_neg_train_samples);
        j.trainer = trainer;

        long cur = 0;

        // load up our positive test samples
        while (cur < num_pos_test_samples)
        {
            if (y(pos_idx) == +1.0)
            {
                j.x_test(cur) = x(pos_idx);
                j.y_test(cur) = +1.0;
                ++cur;
            }
            pos_idx = (pos_idx+1)%x.nr();
        }

        // load up our negative test samples
        while (cur < j.x_test.nr())
        {
            if (y(neg_idx) == -1.0)
            {
                j.x_test(cur) = x(neg_idx);
                j.y_test(cur) = -1.0;
                ++cur;
            }
            neg_idx = (neg_idx+1)%x.nr();
        }

        // load the training data from the data following whatever we loaded
        // as the testing data
        long train_pos_idx = pos_idx;
        long train_neg_idx = neg_idx;
        cur = 0;

        // load up our positive train samples
        while (cur < num_pos_train_samples)
        {
            if (y(train_pos_idx) == +1.0)
            {
                j.x_train(cur) = x(train_pos_idx);
                j.y_train(cur) = +1.0;
                ++cur;
            }
            train_pos_idx = (train_pos_idx+1)%x.nr();
        }

        // load up our negative train samples
        while (cur < j.x_train.nr())
        {
            if (y(train_neg_idx) == -1.0)
            {
                j.x_train(cur) = x(train_neg_idx);
                j.y_train(cur) = -1.0;
                ++cur;
            }
            train_neg_idx = (train_neg_idx+1)%x.nr();
        }

        // hand this fold off to the worker threads.  Note that every job is
        // queued before any result is read back below.
        threads.job_pipe.enqueue(j);
    } // for (long i = 0; i < folds; ++i)

    // accumulate the per-fold accuracies as the workers finish
    matrix<scalar_type, 1, 2, mem_manager_type> res;
    matrix<scalar_type, 1, 2, mem_manager_type> temp_res;
    set_all_elements(res,0);
    for (long i = 0; i < folds; ++i)
    {
        threads.res_pipe.dequeue(temp_res);
        res += temp_res;
    }

    // return the average accuracy over all the folds
    return res/(scalar_type)folds;
}
template <
    typename trainer_type,
    typename in_sample_vector_type,
    typename in_scalar_vector_type
    >
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type>
cross_validate_trainer_threaded (
    const trainer_type& trainer,
    const in_sample_vector_type& x,
    const in_scalar_vector_type& y,
    const long folds,
    const long num_threads
)
/*!
    Public entry point.  Converts whatever vector containers the caller
    passed into matrix expressions and forwards everything on to
    cross_validate_trainer_threaded_impl() which does the real work.
!*/
{
    return cross_validate_trainer_threaded_impl(
        trainer,
        vector_to_matrix(x), vector_to_matrix(y),
        folds, num_threads);
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_SVm_THREADED_
// Copyright (C) 2008 Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_SVm_THREADED_ABSTRACT_
#ifdef DLIB_SVm_THREADED_ABSTRACT_
#include "../matrix/matrix_abstract.h"
#include "../algs.h"
#include "../svm.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename trainer_type,
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type>
cross_validate_trainer_threaded (
const trainer_type& trainer,
const in_sample_vector_type& x,
const in_scalar_vector_type& y,
const long folds,
const long num_threads
);
/*!
requires
- is_binary_classification_problem(x,y) == true
- 1 < folds <= x.nr()
- trainer_type == some kind of trainer object (e.g. svm_nu_trainer)
- num_threads > 0
ensures
- performs k-fold cross validation by using the given trainer to solve the
given binary classification problem for the given number of folds.
Each fold is tested using the output of the trainer and the average
classification accuracy from all folds is returned.
- uses num_threads threads of execution in doing the cross validation.
- The accuracy is returned in a row vector, let us call it R. Both
quantities in R are numbers between 0 and 1 which represent the fraction
of examples correctly classified. R(0) is the fraction of +1 examples
correctly classified and R(1) is the fraction of -1 examples correctly
classified.
- The number of folds used is given by the folds argument.
throws
- any exceptions thrown by trainer.train()
- std::bad_alloc
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_SVm_THREADED_ABSTRACT_
// Copyright (C) 2008 Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_SVm_THREADED_HEADER
#define DLIB_SVm_THREADED_HEADER
#include "svm.h"
#include "svm/svm_threaded.h"
#endif // DLIB_SVm_THREADED_HEADER
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment