Commit 73082cdb authored by Davis King

Added the ability to automatically select a reasonable basis to the svm_c_ekm_trainer.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403618
parent abc86877
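
For reference, a minimal usage sketch of the new behaviour follows; the radial_basis_kernel, the toy data, and the parameter values are illustrative only and not part of this commit. The point is that set_basis() is never called: train() now selects a basis automatically.

#include <dlib/svm.h>
#include <vector>

int main()
{
    typedef dlib::matrix<double,2,1> sample_type;
    typedef dlib::radial_basis_kernel<sample_type> kernel_type;

    std::vector<sample_type> samples;
    std::vector<double> labels;

    // toy data: the label depends on the sign of the first coordinate
    for (int i = -10; i <= 10; ++i)
    {
        sample_type samp;
        samp = i, i*i;
        samples.push_back(samp);
        labels.push_back(i >= 0 ? +1 : -1);
    }

    dlib::svm_c_ekm_trainer<kernel_type> trainer;
    trainer.set_kernel(kernel_type(0.1));
    trainer.set_c(10);

    // Note that set_basis() is never called.  With this change the trainer picks a
    // basis automatically, growing it until the SVM objective stops improving.
    dlib::decision_function<kernel_type> df = trainer.train(samples, labels);
}
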
......@@ -10,6 +10,7 @@
#include "svm_c_linear_trainer.h"
#include "svm_c_ekm_trainer_abstract.h"
#include "../statistics.h"
#include "../rand.h"
#include <vector>
namespace dlib
......@@ -32,6 +33,10 @@ namespace dlib
{
verbose = false;
ekm_stale = true;
initial_basis_size = 5;
basis_size_increment = 5;
max_basis_size = 300;
}
explicit svm_c_ekm_trainer (
......@@ -50,6 +55,10 @@ namespace dlib
ocas.set_c(C);
verbose = false;
ekm_stale = true;
initial_basis_size = 5;
basis_size_increment = 5;
max_basis_size = 300;
}
void set_epsilon (
......@@ -145,6 +154,77 @@ namespace dlib
return (basis.size() != 0);
}
void clear_basis (
)
{
basis.set_size(0);
ekm.clear();
ekm_stale = true;
}
unsigned long get_max_basis_size (
) const
{
return max_basis_size;
}
void set_max_basis_size (
unsigned long max_basis_size_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(max_basis_size_ > 0,
"\t void svm_c_ekm_trainer::set_max_basis_size()"
<< "\n\t max_basis_size_ must be greater than 0"
<< "\n\t max_basis_size_: " << max_basis_size_
<< "\n\t this: " << this
);
max_basis_size = max_basis_size_;
// keep the initial basis size consistent with the new maximum, as documented
if (initial_basis_size > max_basis_size)
initial_basis_size = max_basis_size;
}
unsigned long get_initial_basis_size (
) const
{
return initial_basis_size;
}
void set_initial_basis_size (
unsigned long initial_basis_size_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(initial_basis_size_ > 0,
"\t void svm_c_ekm_trainer::set_initial_basis_size()"
<< "\n\t initial_basis_size_ must be greater than 0"
<< "\n\t initial_basis_size_: " << initial_basis_size_
<< "\n\t this: " << this
);
initial_basis_size = initial_basis_size_;
// grow the maximum basis size if needed so it never falls below the initial size, as documented
if (initial_basis_size > max_basis_size)
max_basis_size = initial_basis_size;
}
unsigned long get_basis_size_increment (
) const
{
return basis_size_increment;
}
void set_basis_size_increment (
unsigned long basis_size_increment_
)
{
// make sure requires clause is not broken
DLIB_ASSERT(basis_size_increment_ > 0,
"\t void svm_c_ekm_trainer::set_basis_size_increment()"
<< "\n\t basis_size_increment_ must be greater than 0"
<< "\n\t basis_size_increment_: " << basis_size_increment_
<< "\n\t this: " << this
);
basis_size_increment = basis_size_increment_;
}
void set_c (
scalar_type C
)
......@@ -211,8 +291,22 @@ namespace dlib
const in_scalar_vector_type& y
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
"\t decision_function svm_c_ekm_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.nr(): " << x.nr()
<< "\n\t y.nr(): " << y.nr()
<< "\n\t x.nc(): " << x.nc()
<< "\n\t y.nc(): " << y.nc()
<< "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
);
scalar_type obj;
return do_train(vector_to_matrix(x),vector_to_matrix(y),obj);
if (basis_loaded())
return do_train_user_basis(vector_to_matrix(x),vector_to_matrix(y),obj);
else
return do_train_auto_basis(vector_to_matrix(x),vector_to_matrix(y),obj);
}
template <
......@@ -225,7 +319,21 @@ namespace dlib
scalar_type& svm_objective
) const
{
return do_train(vector_to_matrix(x),vector_to_matrix(y),svm_objective);
// make sure requires clause is not broken
DLIB_ASSERT(is_binary_classification_problem(x,y) == true,
"\t decision_function svm_c_ekm_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.nr(): " << x.nr()
<< "\n\t y.nr(): " << y.nr()
<< "\n\t x.nc(): " << x.nc()
<< "\n\t y.nc(): " << y.nc()
<< "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
);
if (basis_loaded())
return do_train_user_basis(vector_to_matrix(x),vector_to_matrix(y),svm_objective);
else
return do_train_auto_basis(vector_to_matrix(x),vector_to_matrix(y),svm_objective);
}
......@@ -235,24 +343,18 @@ namespace dlib
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const decision_function<kernel_type> do_train (
const decision_function<kernel_type> do_train_user_basis (
const in_sample_vector_type& x,
const in_scalar_vector_type& y,
scalar_type& svm_objective
) const
/*!
requires
- basis_loaded() == true
ensures
- trains an SVM with the user supplied basis
!*/
{
// make sure requires clause is not broken
DLIB_ASSERT(basis_loaded() == true && is_binary_classification_problem(x,y) == true,
"\t decision_function svm_c_ekm_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.nr(): " << x.nr()
<< "\n\t y.nr(): " << y.nr()
<< "\n\t x.nc(): " << x.nc()
<< "\n\t y.nc(): " << y.nc()
<< "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y)
<< "\n\t basis_loaded(): " << basis_loaded()
);
if (ekm_stale)
{
ekm.load(kern, basis);
......@@ -298,6 +400,158 @@ namespace dlib
return final_df;
}
template <
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const decision_function<kernel_type> do_train_auto_basis (
const in_sample_vector_type& x,
const in_scalar_vector_type& y,
scalar_type& svm_objective
) const
{
std::vector<matrix<scalar_type,0,1, mem_manager_type> > proj_samples(x.size());
decision_function<linear_kernel<matrix<scalar_type,0,1, mem_manager_type> > > df;
// we will use a linearly_independent_subset_finder to store our basis set.
linearly_independent_subset_finder<kernel_type> lisf(get_kernel(), max_basis_size);
dlib::rand::kernel_1a rnd;
// first pick the initial basis set randomly
for (unsigned long i = 0; i < 10*initial_basis_size && lisf.dictionary_size() < initial_basis_size; ++i)
{
lisf.add(x(rnd.get_random_32bit_number()%x.size()));
}
svm_c_linear_trainer<linear_kernel<matrix<scalar_type,0,1,mem_manager_type> > > trainer(ocas);
const scalar_type min_epsilon = trainer.get_epsilon();
// while we are determining what the basis set will be we are going to use a very
// loose stopping condition. We will tighten it back up before producing the
// final decision_function.
trainer.set_epsilon(0.2);
scalar_type prev_svm_objective = std::numeric_limits<scalar_type>::max();
// This loop is where we try to generate a basis for SVM training. We will
// do this by repeatedly training the SVM and adding a few points which violate the
// margin to the basis in each iteration.
while (true)
{
running_stats<scalar_type> rs;
ekm.load(lisf);
// first project all samples into the span of the current basis
for (long i = 0; i < x.size(); ++i)
{
if (verbose)
{
scalar_type err;
proj_samples[i] = ekm.project(x(i), err);
rs.add(err);
}
else
{
proj_samples[i] = ekm.project(x(i));
}
}
// if the basis is already as big as it's going to get then just do the most
// accurate training right now.
if (lisf.dictionary_size() == max_basis_size)
trainer.set_epsilon(min_epsilon);
while (true)
{
// now do the training.
df = trainer.train(proj_samples, y, svm_objective);
if (svm_objective < prev_svm_objective)
break;
// If the training didn't reduce the objective more than last time then
// try lowering the epsilon and doing it again.
if (trainer.get_epsilon() > min_epsilon)
{
trainer.set_epsilon(std::max(trainer.get_epsilon()*0.5, min_epsilon));
if (verbose)
std::cout << "Reducing epsilon to " << trainer.get_epsilon() << std::endl;
}
else
break;
}
if (verbose)
{
std::cout << "\nMean EKM projection error: " << rs.mean() << std::endl;
std::cout << "Standard deviaion of EKM projection error: " << rs.stddev() << std::endl;
std::cout << "svm objective: " << svm_objective << std::endl;
std::cout << "basis size: " << lisf.dictionary_size() << std::endl;
}
// if we failed to make progress on this iteration then we are done
if (svm_objective >= prev_svm_objective)
break;
prev_svm_objective = svm_objective;
// now add more elements to the basis
unsigned long count = 0;
for (unsigned long j = 0;
(j < 100*basis_size_increment) && (count < basis_size_increment) && (lisf.dictionary_size() < max_basis_size);
++j)
{
// pick a random sample
const unsigned long idx = rnd.get_random_32bit_number()%x.size();
// If it is a margin violator then it is useful to add it into the basis set.
if (df(proj_samples[idx])*y(idx) < 1)
{
// Add the sample into the basis set if it is linearly independent of all the
// vectors already in the basis set.
if (lisf.add(x(idx)))
++count;
}
}
// if we couldn't add any more basis vectors then stop
if (count == 0)
{
if (verbose)
std::cout << "Stopping, couldn't add more basis vectors." << std::endl;
break;
}
}
// if we haven't already done so then make sure to run training with the tight epsilon
// before we return our results.
if (trainer.get_epsilon() > min_epsilon)
{
trainer.set_epsilon(min_epsilon);
df = trainer.train(proj_samples, y, svm_objective);
}
if (verbose)
{
std::cout << "Final svm objective: " << svm_objective << std::endl;
}
decision_function<kernel_type> final_df;
final_df = ekm.convert_to_decision_function(df.basis_vectors(0));
final_df.b = df.b;
// we don't need the ekm anymore so clear it out
ekm.clear();
return final_df;
}
/*!
CONVENTION
- if (ekm_stale) then
......@@ -309,6 +563,10 @@ namespace dlib
bool verbose;
kernel_type kern;
unsigned long max_basis_size;
unsigned long basis_size_increment;
unsigned long initial_basis_size;
matrix<sample_type,0,1,mem_manager_type> basis;
mutable empirical_kernel_map<kernel_type> ekm;
......
......@@ -24,10 +24,7 @@ namespace dlib
This object represents a tool for training the C formulation of
a support vector machine. It is implemented using the empirical_kernel_map
to kernelize the svm_c_linear_trainer. This makes it a very fast algorithm
but means the user must supply a set of basis vectors.
For details about the "basis vectors" see the empirical_kernel_map
documentation. In particular, see it's example program.
capable of learning from very large datasets.
!*/
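As an aside, the kernelization described above amounts conceptually to something like the sketch below. The helper ekm_train and the concrete radial_basis_kernel types are illustrative assumptions, not part of this header, and the real do_train_*() routines additionally handle the automatic basis growth added in this commit.
#include <dlib/svm.h>
#include <vector>
typedef dlib::matrix<double,0,1> sample_type;
typedef dlib::radial_basis_kernel<sample_type> kernel_type;
typedef dlib::linear_kernel<sample_type> lin_kernel;
dlib::decision_function<kernel_type> ekm_train (
const kernel_type& kern,
const std::vector<sample_type>& basis_samples,
const std::vector<sample_type>& samples,
const std::vector<double>& labels,
double C
)
{
// project every sample into the span of the basis
dlib::empirical_kernel_map<kernel_type> ekm;
ekm.load(kern, basis_samples);
std::vector<sample_type> proj(samples.size());
for (unsigned long i = 0; i < samples.size(); ++i)
proj[i] = ekm.project(samples[i]);
// train an ordinary linear SVM in that projected space
dlib::svm_c_linear_trainer<lin_kernel> linear_trainer;
linear_trainer.set_c(C);
dlib::decision_function<lin_kernel> lin_df = linear_trainer.train(proj, labels);
// map the linear solution back into the original kernel's feature space
dlib::decision_function<kernel_type> df = ekm.convert_to_decision_function(lin_df.basis_vectors(0));
df.b = lin_df.b;
return df;
}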
public:
......@@ -48,6 +45,9 @@ namespace dlib
- #get_c_class2() == 1
- #get_epsilon() == 0.001
- #basis_loaded() == false
- #get_initial_basis_size() == 5
- #get_basis_size_increment() == 5
- #get_max_basis_size() == 300
- this object will not be verbose unless be_verbose() is called
!*/
......@@ -65,6 +65,9 @@ namespace dlib
- #get_c_class2() == C
- #get_epsilon() == 0.001
- #basis_loaded() == false
- #get_initial_basis_size() == 5
- #get_basis_size_increment() == 5
- #get_max_basis_size() == 300
- this object will not be verbose unless be_verbose() is called
!*/
......@@ -162,7 +165,78 @@ namespace dlib
) const;
/*!
ensures
- returns true if this object has been loaded with basis vectors and false otherwise.
- returns true if this object has been loaded with user supplied basis vectors and false otherwise.
!*/
void clear_basis (
);
/*!
ensures
- #basis_loaded() == false
!*/
unsigned long get_max_basis_size (
) const;
/*!
ensures
- returns the maximum number of basis vectors this object is allowed
to use. This parameter only matters when the user has not supplied
a basis via set_basis().
!*/
void set_max_basis_size (
unsigned long max_basis_size
);
/*!
requires
- max_basis_size > 0
ensures
- #get_max_basis_size() == max_basis_size
- if (get_initial_basis_size() > max_basis_size) then
- #get_initial_basis_size() == max_basis_size
!*/
unsigned long get_initial_basis_size (
) const;
/*!
ensures
- If the user does not supply a basis via set_basis() then this object
will generate one automatically. It does this by starting with
a small basis of size N and repeatedly adding basis vectors to it
until a stopping condition is reached. This function returns that
initial size N.
!*/
void set_initial_basis_size (
unsigned long initial_basis_size
);
/*!
requires
- initial_basis_size > 0
ensures
- #get_initial_basis_size() == initial_basis_size
- if (initial_basis_size > get_max_basis_size()) then
- #get_max_basis_size() == initial_basis_size
!*/
unsigned long get_basis_size_increment (
) const;
/*!
ensures
- If the user does not supply a basis via set_basis() then this object
will generate one automatically. It does this by starting with a small
basis and repeatedly adding sets of N basis vectors to it until a stopping
condition is reached. This function returns that increment size N.
!*/
void set_basis_size_increment (
unsigned long basis_size_increment
);
/*!
requires
- basis_size_increment > 0
ensures
- #get_basis_size_increment() == basis_size_increment
!*/
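Taken together, these three parameters control how the automatically selected basis grows. For instance, the following values are illustrative only and the kernel_type typedef is an assumption:
typedef dlib::radial_basis_kernel<dlib::matrix<double,0,1> > kernel_type;
dlib::svm_c_ekm_trainer<kernel_type> trainer;
trainer.set_initial_basis_size(10);   // start the search from 10 randomly chosen samples
trainer.set_basis_size_increment(10); // add up to 10 margin violators per iteration
trainer.set_max_basis_size(500);      // never let the basis grow beyond 500 vectors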
void set_c (
......@@ -230,7 +304,6 @@ namespace dlib
) const;
/*!
requires
- basis_loaded() == true
- is_binary_classification_problem(x,y) == true
- x == a matrix or something convertible to a matrix via vector_to_matrix().
Also, x should contain sample_type objects.
......@@ -239,6 +312,11 @@ namespace dlib
ensures
- trains a C support vector classifier given the training samples in x and
labels in y.
- if (basis_loaded()) then
- training will be carried out in the span of the user supplied basis vectors
- else
- this object will attempt to automatically select an appropriate basis
- returns a decision function F with the following properties:
- if (new_x is a sample predicted to have +1 label) then
- F(new_x) >= 0
......@@ -257,7 +335,6 @@ namespace dlib
) const;
/*!
requires
- basis_loaded() == true
- is_binary_classification_problem(x,y) == true
- x == a matrix or something convertible to a matrix via vector_to_matrix().
Also, x should contain sample_type objects.
......@@ -266,6 +343,11 @@ namespace dlib
ensures
- trains a C support vector classifier given the training samples in x and
labels in y.
- if (basis_loaded()) then
- training will be carried out in the span of the user supplied basis vectors
- else
- this object will attempt to automatically select an appropriate basis
- #svm_objective == the final value of the SVM objective function
- returns a decision function F with the following properties:
- if (new_x is a sample predicted to have +1 label) then
......