Commit c6cb4d81 authored by Davis King's avatar Davis King

Changed the kcentroid so that you can tell it to keep the most linearly

independent vectors rather than the newest vectors.  I then changed the
svm_pegasos object so that it has a max number of support vector setting
so that the user can supply an upper limit on the number of support
vectors to use.  Note that these changes broke backwards compatibility
with the previous serialized format of these objects.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402928
parent 56ab4f8b
...@@ -19,13 +19,46 @@ namespace dlib ...@@ -19,13 +19,46 @@ namespace dlib
class kcentroid class kcentroid
{ {
/*! /*!
This is an implementation of an online algorithm for recursively estimating the This object represents a weighted sum of sample points in a kernel induced
centroid of a sequence of training points. It uses the sparsification technique feature space. It can be used to kernelized any algorithm that requires only
described in the paper The Kernel Recursive Least Squares Algorithm by Yaakov Engel. the ability to perform vector addition, subtraction, scalar multiplication,
and inner products. It uses the sparsification technique described in the
paper The Kernel Recursive Least Squares Algorithm by Yaakov Engel.
To understand the code it would also be useful to consult page 114 of the book
Kernel Methods for Pattern Analysis by Taylor and Cristianini as well as page 554
(particularly equation 18.31) of the book Learning with Kernels by Scholkopf and
Smola. Everything you really need to know is in the Engel paper. But the other
books help give more perspective on the issues involved.
INITIAL VALUE
- min_strength == 0
- min_vect_idx == 0
- K_inv.size() == 0
- K.size() == 0
- dictionary.size() == 0
- bias == 0
- bias_is_stale == false
CONVENTION
- max_dictionary_size() == my_max_dictionary_size
- get_kernel() == kernel
- K.nr() == dictionary.size()
- K.nc() == dictionary.size()
- for all valid r,c:
- K(r,c) == kernel(dictionary[r], dictionary[c])
- K_inv == inv(K)
- if (dictionary.size() == my_max_dictionary_size && my_remove_oldest_first == false) then
- for all valid 0 < i < dictionary.size():
- Let STRENGTHS[i] == the delta you would get for dictionary[i] (i.e. Approximately
Linearly Dependent value) if you removed dictionary[i] from this object and then
tried to add it back in.
- min_strength == the minimum value from STRENGTHS
- min_vect_idx == the index of the element in STRENGTHS with the smallest value
To understand the code it would also be useful to consult page 114 of the book Kernel
Methods for Pattern Analysis by Taylor and Cristianini as well as page 554
(particularly equation 18.31) of the book Learning with Kernels by Scholkopf and Smola.
!*/ !*/
public: public:
...@@ -37,8 +70,10 @@ namespace dlib ...@@ -37,8 +70,10 @@ namespace dlib
explicit kcentroid ( explicit kcentroid (
const kernel_type& kernel_, const kernel_type& kernel_,
scalar_type tolerance_ = 0.001, scalar_type tolerance_ = 0.001,
unsigned long max_dictionary_size_ = 1000000 unsigned long max_dictionary_size_ = 1000000,
bool remove_oldest_first_ = true
) : ) :
my_remove_oldest_first(remove_oldest_first_),
kernel(kernel_), kernel(kernel_),
my_tolerance(tolerance_), my_tolerance(tolerance_),
my_max_dictionary_size(max_dictionary_size_), my_max_dictionary_size(max_dictionary_size_),
...@@ -46,11 +81,12 @@ namespace dlib ...@@ -46,11 +81,12 @@ namespace dlib
bias_is_stale(false) bias_is_stale(false)
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT(tolerance_ >= 0, DLIB_ASSERT(tolerance_ >= 0 && max_dictionary_size_ > 0,
"\tkcentroid::kcentroid()" "\tkcentroid::kcentroid()"
<< "\n\t You have to give a positive tolerance" << "\n\t You have to give a positive tolerance"
<< "\n\t this: " << this << "\n\t this: " << this
<< "\n\t tolerance: " << tolerance_ << "\n\t tolerance_: " << tolerance_
<< "\n\t max_dictionary_size_: " << max_dictionary_size_
); );
clear_dictionary(); clear_dictionary();
...@@ -66,6 +102,12 @@ namespace dlib ...@@ -66,6 +102,12 @@ namespace dlib
return my_max_dictionary_size; return my_max_dictionary_size;
} }
bool remove_oldest_first (
) const
{
return my_remove_oldest_first;
}
const kernel_type& get_kernel ( const kernel_type& get_kernel (
) const ) const
{ {
...@@ -77,6 +119,8 @@ namespace dlib ...@@ -77,6 +119,8 @@ namespace dlib
dictionary.clear(); dictionary.clear();
alpha.clear(); alpha.clear();
min_strength = 0;
min_vect_idx = 0;
K_inv.set_size(0,0); K_inv.set_size(0,0);
K.set_size(0,0); K.set_size(0,0);
samples_seen = 0; samples_seen = 0;
...@@ -215,6 +259,10 @@ namespace dlib ...@@ -215,6 +259,10 @@ namespace dlib
kcentroid& item kcentroid& item
) )
{ {
exchange(min_strength, item.min_strength);
exchange(min_vect_idx, item.min_vect_idx);
exchange(my_remove_oldest_first, item.my_remove_oldest_first);
exchange(kernel, item.kernel); exchange(kernel, item.kernel);
dictionary.swap(item.dictionary); dictionary.swap(item.dictionary);
alpha.swap(item.alpha); alpha.swap(item.alpha);
...@@ -234,6 +282,10 @@ namespace dlib ...@@ -234,6 +282,10 @@ namespace dlib
friend void serialize(const kcentroid& item, std::ostream& out) friend void serialize(const kcentroid& item, std::ostream& out)
{ {
serialize(item.min_strength, out);
serialize(item.min_vect_idx, out);
serialize(item.my_remove_oldest_first, out);
serialize(item.kernel, out); serialize(item.kernel, out);
serialize(item.dictionary, out); serialize(item.dictionary, out);
serialize(item.alpha, out); serialize(item.alpha, out);
...@@ -248,6 +300,10 @@ namespace dlib ...@@ -248,6 +300,10 @@ namespace dlib
friend void deserialize(kcentroid& item, std::istream& in) friend void deserialize(kcentroid& item, std::istream& in)
{ {
deserialize(item.min_strength, in);
deserialize(item.min_vect_idx, in);
deserialize(item.my_remove_oldest_first, in);
deserialize(item.kernel, in); deserialize(item.kernel, in);
deserialize(item.dictionary, in); deserialize(item.dictionary, in);
deserialize(item.alpha, in); deserialize(item.alpha, in);
...@@ -313,6 +369,7 @@ namespace dlib ...@@ -313,6 +369,7 @@ namespace dlib
if (do_test) if (do_test)
{ {
refresh_bias();
test_result = std::sqrt(kx + bias - 2*trans(vector_to_matrix(alpha))*k); test_result = std::sqrt(kx + bias - 2*trans(vector_to_matrix(alpha))*k);
} }
...@@ -323,13 +380,29 @@ namespace dlib ...@@ -323,13 +380,29 @@ namespace dlib
// if this new vector isn't approximately linearly dependent on the vectors // if this new vector isn't approximately linearly dependent on the vectors
// in our dictionary. // in our dictionary.
if (delta > my_tolerance) if (delta > min_strength && delta > my_tolerance)
{ {
bool need_to_update_min_strength = false;
if (dictionary.size() >= my_max_dictionary_size) if (dictionary.size() >= my_max_dictionary_size)
{ {
// We need to remove one of the old members of the dictionary before // We need to remove one of the old members of the dictionary before
// we proceed with adding a new one. So remove the oldest dictionary vector. // we proceed with adding a new one.
const long idx_to_remove = 0; long idx_to_remove;
if (my_remove_oldest_first)
{
// remove the oldest one
idx_to_remove = 0;
}
else
{
// if we have never computed the min_strength then we should compute it
if (min_strength == 0)
recompute_min_strength();
// select the dictionary vector that is most linearly dependent for removal
idx_to_remove = min_vect_idx;
need_to_update_min_strength = true;
}
remove_dictionary_vector(idx_to_remove); remove_dictionary_vector(idx_to_remove);
...@@ -377,6 +450,13 @@ namespace dlib ...@@ -377,6 +450,13 @@ namespace dlib
alpha[i] *= cscale; alpha[i] *= cscale;
} }
alpha.push_back(xscale); alpha.push_back(xscale);
if (need_to_update_min_strength)
{
// now we have to recompute the min_strength in this case
recompute_min_strength();
}
} }
else else
{ {
...@@ -432,6 +512,35 @@ namespace dlib ...@@ -432,6 +512,35 @@ namespace dlib
K = removerc(K,i,i); K = removerc(K,i,i);
} }
void recompute_min_strength (
)
/*!
ensures
- recomputes the min_strength and min_vect_idx values
so that they are correct with respect to the CONVENTION
- uses the this->a variable so after this function runs that variable
will contain a different value.
!*/
{
min_strength = std::numeric_limits<scalar_type>::max();
// here we loop over each dictionary vector and compute what its delta would be if
// we were to remove it from the dictionary and then try to add it back in.
for (unsigned long i = 0; i < dictionary.size(); ++i)
{
// compute a = K_inv*k but where dictionary vector i has been removed
a = (removerc(K_inv,i,i) - remove_row(colm(K_inv,i)/K_inv(i,i),i)*remove_col(rowm(K_inv,i),i)) *
(remove_row(colm(K,i),i));
scalar_type delta = K(i,i) - trans(remove_row(colm(K,i),i))*a;
if (delta < min_strength)
{
min_strength = delta;
min_vect_idx = i;
}
}
}
typedef std_allocator<sample_type, mem_manager_type> alloc_sample_type; typedef std_allocator<sample_type, mem_manager_type> alloc_sample_type;
...@@ -440,6 +549,10 @@ namespace dlib ...@@ -440,6 +549,10 @@ namespace dlib
typedef std::vector<scalar_type,alloc_scalar_type> alpha_vector_type; typedef std::vector<scalar_type,alloc_scalar_type> alpha_vector_type;
scalar_type min_strength;
unsigned long min_vect_idx;
bool my_remove_oldest_first;
kernel_type kernel; kernel_type kernel;
dictionary_vector_type dictionary; dictionary_vector_type dictionary;
alpha_vector_type alpha; alpha_vector_type alpha;
......
...@@ -39,10 +39,9 @@ namespace dlib ...@@ -39,10 +39,9 @@ namespace dlib
Also note that the algorithm internally keeps a set of "dictionary vectors" Also note that the algorithm internally keeps a set of "dictionary vectors"
that are used to represent the centroid. You can force the algorithm to use that are used to represent the centroid. You can force the algorithm to use
no more than a set number of vectors by setting the 3rd constructor argument no more than a set number of vectors by setting the 3rd constructor argument
to whatever you want. However, note that doing this causes the algorithm to whatever you want.
to bias it's results towards more recent training examples.
This object also uses the sparsification technique described in the paper The This object uses the sparsification technique described in the paper The
Kernel Recursive Least Squares Algorithm by Yaakov Engel. This technique Kernel Recursive Least Squares Algorithm by Yaakov Engel. This technique
allows us to keep the number of dictionary vectors down to a minimum. In fact, allows us to keep the number of dictionary vectors down to a minimum. In fact,
the object has a user selectable tolerance parameter that controls the trade off the object has a user selectable tolerance parameter that controls the trade off
...@@ -58,16 +57,19 @@ namespace dlib ...@@ -58,16 +57,19 @@ namespace dlib
explicit kcentroid ( explicit kcentroid (
const kernel_type& kernel_, const kernel_type& kernel_,
scalar_type tolerance_ = 0.001, scalar_type tolerance_ = 0.001,
unsigned long max_dictionary_size_ = 1000000 unsigned long max_dictionary_size_ = 1000000,
bool remove_oldest_first_ = true
); );
/*! /*!
requires requires
- tolerance >= 0 - tolerance >= 0
- max_dictionary_size_ > 0
ensures ensures
- this object is properly initialized - this object is properly initialized
- #tolerance() == tolerance_ - #tolerance() == tolerance_
- #get_kernel() == kernel_ - #get_kernel() == kernel_
- #max_dictionary_size() == max_dictionary_size_ - #max_dictionary_size() == max_dictionary_size_
- #remove_oldest_first() == remove_oldest_first_
!*/ !*/
const kernel_type& get_kernel ( const kernel_type& get_kernel (
...@@ -86,6 +88,19 @@ namespace dlib ...@@ -86,6 +88,19 @@ namespace dlib
greater than max_dictionary_size(). greater than max_dictionary_size().
!*/ !*/
bool remove_oldest_first (
) const;
/*!
ensures
- When the maximum dictionary size is reached then this object sometimes
needs to discard dictionary vectors when new samples are added via
one of the train functions. If remove_oldest_first() returns true then
this object discards the oldest dictionary vectors when the maximum
dictionary size is reached. Otherise, if this function returns false
then it means that this object discards the most linearly dependent
dictionary vectors.
!*/
unsigned long dictionary_size ( unsigned long dictionary_size (
) const; ) const;
/*! /*!
......
...@@ -32,31 +32,35 @@ namespace dlib ...@@ -32,31 +32,35 @@ namespace dlib
svm_pegasos ( svm_pegasos (
) : ) :
max_sv(40),
lambda(0.0001), lambda(0.0001),
tau(0.01), tau(0.01),
tolerance(0.01), tolerance(0.01),
train_count(0), train_count(0),
w(offset_kernel<kernel_type>(kernel,tau),tolerance) w(offset_kernel<kernel_type>(kernel,tau),tolerance, max_sv, false)
{ {
} }
svm_pegasos ( svm_pegasos (
const kernel_type& kernel_, const kernel_type& kernel_,
const scalar_type& lambda_, const scalar_type& lambda_,
const scalar_type& tolerance_ const scalar_type& tolerance_,
unsigned long max_num_sv
) : ) :
max_sv(max_num_sv),
kernel(kernel_), kernel(kernel_),
lambda(lambda_), lambda(lambda_),
tau(0.01), tau(0.01),
tolerance(tolerance_), tolerance(tolerance_),
train_count(0), train_count(0),
w(offset_kernel<kernel_type>(kernel,tau),tolerance) w(offset_kernel<kernel_type>(kernel,tau),tolerance, max_sv, false)
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT(lambda > 0 && tolerance > 0, DLIB_ASSERT(lambda > 0 && tolerance > 0 && max_num_sv > 0,
"\tsvm_pegasos::svm_pegasos(kernel,lambda,tolerance)" "\tsvm_pegasos::svm_pegasos(kernel,lambda,tolerance)"
<< "\n\t invalid inputs were given to this function" << "\n\t invalid inputs were given to this function"
<< "\n\t lambda: " << lambda << "\n\t lambda: " << lambda
<< "\n\t max_num_sv: " << max_num_sv
); );
} }
...@@ -64,7 +68,7 @@ namespace dlib ...@@ -64,7 +68,7 @@ namespace dlib
) )
{ {
// reset the w vector back to its initial state // reset the w vector back to its initial state
w = kc_type(offset_kernel<kernel_type>(kernel,tau),tolerance); w = kc_type(offset_kernel<kernel_type>(kernel,tau),tolerance, max_sv, false);
train_count = 0; train_count = 0;
} }
...@@ -76,6 +80,26 @@ namespace dlib ...@@ -76,6 +80,26 @@ namespace dlib
clear(); clear();
} }
void set_max_num_sv (
unsigned long max_num_sv
)
{
// make sure requires clause is not broken
DLIB_ASSERT(max_num_sv > 0,
"\tvoid svm_pegasos::set_max_num_sv(max_num_sv)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t max_num_sv: " << max_num_sv
);
max_sv = max_num_sv;
clear();
}
unsigned long get_max_num_sv (
) const
{
return max_sv;
}
void set_tolerance ( void set_tolerance (
double tol double tol
) )
...@@ -180,6 +204,7 @@ namespace dlib ...@@ -180,6 +204,7 @@ namespace dlib
svm_pegasos& item svm_pegasos& item
) )
{ {
exchange(max_sv, item.max_sv);
exchange(kernel, item.kernel); exchange(kernel, item.kernel);
exchange(lambda, item.lambda); exchange(lambda, item.lambda);
exchange(tau, item.tau); exchange(tau, item.tau);
...@@ -190,6 +215,7 @@ namespace dlib ...@@ -190,6 +215,7 @@ namespace dlib
friend void serialize(const svm_pegasos& item, std::ostream& out) friend void serialize(const svm_pegasos& item, std::ostream& out)
{ {
serialize(item.max_sv, out);
serialize(item.kernel, out); serialize(item.kernel, out);
serialize(item.lambda, out); serialize(item.lambda, out);
serialize(item.tau, out); serialize(item.tau, out);
...@@ -200,6 +226,7 @@ namespace dlib ...@@ -200,6 +226,7 @@ namespace dlib
friend void deserialize(svm_pegasos& item, std::istream& in) friend void deserialize(svm_pegasos& item, std::istream& in)
{ {
deserialize(item.max_sv, in);
deserialize(item.kernel, in); deserialize(item.kernel, in);
deserialize(item.lambda, in); deserialize(item.lambda, in);
deserialize(item.tau, in); deserialize(item.tau, in);
...@@ -210,6 +237,7 @@ namespace dlib ...@@ -210,6 +237,7 @@ namespace dlib
private: private:
unsigned long max_sv;
kernel_type kernel; kernel_type kernel;
scalar_type lambda; scalar_type lambda;
scalar_type tau; scalar_type tau;
......
...@@ -59,23 +59,27 @@ namespace dlib ...@@ -59,23 +59,27 @@ namespace dlib
- #get_lambda() == 0.0001 - #get_lambda() == 0.0001
- #get_tolerance() == 0.01 - #get_tolerance() == 0.01
- #get_train_count() == 0 - #get_train_count() == 0
- #get_max_num_sv() == 40
!*/ !*/
svm_pegasos ( svm_pegasos (
const kernel_type& kernel_, const kernel_type& kernel_,
const scalar_type& lambda_, const scalar_type& lambda_,
const scalar_type& tolerance_ const scalar_type& tolerance_,
unsigned long max_num_sv
); );
/*! /*!
requires requires
- lambda_ > 0 - lambda_ > 0
- tolerance_ > 0 - tolerance_ > 0
- max_num_sv > 0
ensures ensures
- this object is properly initialized - this object is properly initialized
- #get_lambda() == lambda_ - #get_lambda() == lambda_
- #get_tolerance() == tolerance_ - #get_tolerance() == tolerance_
- #get_kernel() == kernel_ - #get_kernel() == kernel_
- #get_train_count() == 0 - #get_train_count() == 0
- #get_max_num_sv() == max_num_sv
!*/ !*/
void clear ( void clear (
...@@ -115,6 +119,14 @@ namespace dlib ...@@ -115,6 +119,14 @@ namespace dlib
decision function but will use more support vectors. decision function but will use more support vectors.
!*/ !*/
unsigned long get_max_num_sv (
) const;
/*!
ensures
- returns the maximum number of support vectors this object is
allowed to use.
!*/
const kernel_type get_kernel ( const kernel_type get_kernel (
) const; ) const;
/*! /*!
...@@ -144,6 +156,18 @@ namespace dlib ...@@ -144,6 +156,18 @@ namespace dlib
(i.e. clears any memory of previous training) (i.e. clears any memory of previous training)
!*/ !*/
void set_max_num_sv (
unsigned long max_num_sv
);
/*!
requires
- max_num_sv > 0
ensures
- #get_max_num_sv() == max_num_sv
- #get_train_count() == 0
(i.e. clears any memory of previous training)
!*/
void set_lambda ( void set_lambda (
scalar_type lambda_ scalar_type lambda_
); );
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment