Commit 8d2243b0 authored by Davis King's avatar Davis King

Added the ability to use a kernel cache to the batch_trainer object. I also

changed it so that it always calls clear() on the trainer it uses before it
begins training.  This way it always forgets the results of previous trainings.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403143
parent 644bc768
......@@ -10,6 +10,7 @@
#include "kernel.h"
#include "kcentroid.h"
#include <iostream>
#include "../smart_pointers.h"
namespace dlib
{
......@@ -30,6 +31,11 @@ namespace dlib
typedef typename kernel_type::mem_manager_type mem_manager_type;
typedef decision_function<kernel_type> trained_function_type;
template <typename K_>
struct rebind {
typedef svm_pegasos<K_> other;
};
svm_pegasos (
) :
max_sv(40),
......@@ -255,6 +261,22 @@ namespace dlib
svm_pegasos<K>& b
) { a.swap(b); }
// ----------------------------------------------------------------------------------------
template <
typename T,
typename U
>
void replicate_settings (
const svm_pegasos<T>& source,
svm_pegasos<U>& dest
)
{
dest.set_tolerance(source.get_tolerance());
dest.set_lambda(source.get_lambda());
dest.set_max_num_sv(source.get_max_num_sv());
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......@@ -264,6 +286,134 @@ namespace dlib
>
class batch_trainer
{
// ------------------------------------------------------------------------------------
template <
typename K,
typename sample_vector_type
>
class caching_kernel
{
public:
typedef typename K::scalar_type scalar_type;
typedef long sample_type;
//typedef typename K::sample_type sample_type;
typedef typename K::mem_manager_type mem_manager_type;
caching_kernel () : samples(0), counter(0), counter_threshold(0) {}
caching_kernel (
const K& kern,
const sample_vector_type& samps,
long cache_size_
) : real_kernel(kern), samples(&samps), counter(0)
{
cache_size = std::min<long>(cache_size_, samps.size());
cache.reset(new cache_type);
cache->frequency_of_use.resize(samps.size());
for (unsigned long i = 0; i < samps.size(); ++i)
cache->frequency_of_use[i] = std::make_pair(0, i);
// Set the cache build/rebuild threshold so that we have to have
// as many cache misses as there are entries in the cache before
// we build/rebuild.
counter_threshold = samps.size()*cache_size;
cache->sample_location.assign(samples->size(), -1);
}
scalar_type operator() (
const sample_type& a,
const sample_type& b
) const
{
const long a_loc = cache->sample_location[a];
const long b_loc = cache->sample_location[b];
cache->frequency_of_use[a].first += 1;
cache->frequency_of_use[b].first += 1;
// rebuild the cache every so often
if (counter > counter_threshold )
{
build_cache();
}
if (a_loc != -1)
{
return cache->kernel(a_loc, b);
}
else if (b_loc != -1)
{
return cache->kernel(b_loc, a);
}
else
{
++counter;
return real_kernel((*samples)(a), (*samples)(b));
}
}
bool operator== (
const caching_kernel& item
) const
{
return item.real_kernel == real_kernel &&
item.samples == samples;
}
private:
K real_kernel;
void build_cache (
) const
{
std::sort(cache->frequency_of_use.rbegin(), cache->frequency_of_use.rend());
counter = 0;
cache->kernel.set_size(cache_size, samples->size());
cache->sample_location.assign(samples->size(), -1);
// loop over all the samples in the cache
for (unsigned long i = 0; i < cache_size; ++i)
{
const long cur = cache->frequency_of_use[i].second;
cache->sample_location[cur] = i;
// now populate all possible kernel products with the current sample
for (unsigned long j = 0; j < samples->size(); ++j)
{
cache->kernel(i, j) = real_kernel((*samples)(cur), (*samples)(j));
}
}
// reset the frequency of use metrics
for (unsigned long i = 0; i < samples->size(); ++i)
cache->frequency_of_use[i] = std::make_pair(0, i);
}
struct cache_type
{
matrix<float> kernel;
std::vector<long> sample_location; // where in the cache a sample is. -1 means not in cache
std::vector<std::pair<long,long> > frequency_of_use;
};
const sample_vector_type* samples;
shared_ptr<cache_type> cache;
mutable unsigned long counter;
unsigned long counter_threshold;
long cache_size;
};
// ------------------------------------------------------------------------------------
public:
typedef typename trainer_type::kernel_type kernel_type;
typedef typename trainer_type::scalar_type scalar_type;
......@@ -274,25 +424,35 @@ namespace dlib
batch_trainer (
) :
min_learning_rate(0.1)
min_learning_rate(0.1),
use_cache(false),
cache_size(100)
{
}
batch_trainer (
const trainer_type& trainer_,
const scalar_type min_learning_rate_,
bool verbose_
bool verbose_,
bool use_cache_,
long cache_size_ = 100
) :
trainer(trainer_),
min_learning_rate(min_learning_rate_),
verbose(verbose_)
verbose(verbose_),
use_cache(use_cache_),
cache_size(cache_size_)
{
// make sure requires clause is not broken
DLIB_ASSERT(0 < min_learning_rate_,
DLIB_ASSERT(0 < min_learning_rate_ &&
cache_size_ > 0,
"\tbatch_trainer::batch_trainer()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t min_learning_rate_: " << min_learning_rate_
<< "\n\t cache_size_: " << cache_size_
);
trainer.clear();
}
const scalar_type get_min_learning_rate (
......@@ -310,7 +470,10 @@ namespace dlib
const in_scalar_vector_type& y
) const
{
return do_train(vector_to_matrix(x), vector_to_matrix(y));
if (use_cache)
return do_train_cached(vector_to_matrix(x), vector_to_matrix(y));
else
return do_train(vector_to_matrix(x), vector_to_matrix(y));
}
private:
......@@ -344,7 +507,7 @@ namespace dlib
{
if ( (count&0x7FF) == 0)
{
std::cout << "\rrbatch_trainer(): Percent complete: "
std::cout << "\rbatch_trainer(): Percent complete: "
<< 100*min_learning_rate/cur_learning_rate << " " << std::flush;
}
++count;
......@@ -365,10 +528,84 @@ namespace dlib
}
}
template <
typename in_sample_vector_type,
typename in_scalar_vector_type
>
const decision_function<kernel_type> do_train_cached (
const in_sample_vector_type& x,
const in_scalar_vector_type& y
) const
{
typedef typename decision_function<kernel_type>::sample_vector_type sample_vector_type;
typedef typename decision_function<kernel_type>::scalar_vector_type scalar_vector_type;
dlib::rand::kernel_1a rnd;
// make a caching kernel
typedef caching_kernel<kernel_type, in_sample_vector_type> ckernel_type;
ckernel_type ck(trainer.get_kernel(), x, cache_size);
// now rebind the trainer to use the caching kernel
typename trainer_type::template rebind<ckernel_type>::other my_trainer;
my_trainer.set_kernel(ck);
replicate_settings(trainer, my_trainer);
scalar_type cur_learning_rate = min_learning_rate + 10;
unsigned long count = 0;
while (cur_learning_rate > min_learning_rate)
{
const long i = rnd.get_random_32bit_number()%x.size();
// keep feeding the trainer data until its learning rate goes below our threshold
cur_learning_rate = my_trainer.train(i, y(i));
if (verbose)
{
if ( (count&0x7FF) == 0)
{
std::cout << "\rbatch_trainer(): Percent complete: "
<< 100*min_learning_rate/cur_learning_rate << " " << std::flush;
}
++count;
}
}
if (verbose)
{
decision_function<ckernel_type> cached_df;
cached_df = my_trainer.get_decision_function();
std::cout << "\rbatch_trainer(): Percent complete: 100 " << std::endl;
std::cout << " Num sv: " << cached_df.support_vectors.size() << std::endl;
std::cout << " bias: " << cached_df.b << std::endl;
return decision_function<kernel_type> (
cached_df.alpha,
cached_df.b,
trainer.get_kernel(),
rowm(x, cached_df.support_vectors)
);
}
else
{
decision_function<ckernel_type> cached_df;
cached_df = my_trainer.get_decision_function();
return decision_function<kernel_type> (
cached_df.alpha,
cached_df.b,
trainer.get_kernel(),
rowm(x, cached_df.support_vectors)
);
}
}
trainer_type trainer;
scalar_type min_learning_rate;
bool verbose;
bool use_cache;
long cache_size;
}; // end of class batch_trainer
......@@ -380,7 +617,7 @@ namespace dlib
const batch_trainer<trainer_type> batch (
const trainer_type& trainer,
const typename trainer_type::scalar_type min_learning_rate = 0.1
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, false); }
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, false, false); }
// ----------------------------------------------------------------------------------------
......@@ -390,7 +627,29 @@ namespace dlib
const batch_trainer<trainer_type> verbose_batch (
const trainer_type& trainer,
const typename trainer_type::scalar_type min_learning_rate = 0.1
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, true); }
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, true, false); }
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
const batch_trainer<trainer_type> batch_cached (
const trainer_type& trainer,
const typename trainer_type::scalar_type min_learning_rate = 0.1,
long cache_size = 100
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, false, true, cache_size); }
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
const batch_trainer<trainer_type> verbose_batch_cached (
const trainer_type& trainer,
const typename trainer_type::scalar_type min_learning_rate = 0.1,
long cache_size = 100
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, true, true, cache_size); }
// ----------------------------------------------------------------------------------------
......
......@@ -51,6 +51,11 @@ namespace dlib
typedef typename kernel_type::mem_manager_type mem_manager_type;
typedef decision_function<kernel_type> trained_function_type;
template <typename K_>
struct rebind {
typedef svm_pegasos<K_> other;
};
svm_pegasos (
);
/*!
......@@ -272,6 +277,24 @@ namespace dlib
provides serialization support for svm_pegasos objects
!*/
// ----------------------------------------------------------------------------------------
template <
typename T,
typename U
>
void replicate_settings (
const svm_pegasos<T>& source,
svm_pegasos<U>& dest
);
/*!
ensures
- copies all the parameters from the source trainer to the dest trainer.
- #dest.get_tolerance() == source.get_tolerance()
- #dest.get_lambda() == source.get_lambda()
- #dest.get_max_num_sv() == source.get_max_num_sv()
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......@@ -284,6 +307,7 @@ namespace dlib
/*!
REQUIREMENTS ON trainer_type
- trainer_type == some kind of online trainer object (e.g. svm_pegasos)
replicate_settings() must also be defined for the type.
WHAT THIS OBJECT REPRESENTS
This is a trainer object that is meant to wrap online trainer objects
......@@ -313,11 +337,14 @@ namespace dlib
batch_trainer (
const trainer_type& online_trainer,
const scalar_type min_learning_rate_,
bool verbose_
bool verbose_,
bool use_cache_,
long cache_size_ = 100
);
/*!
requires
- min_learning_rate_ > 0
- cache_size_ > 0
ensures
- returns a batch trainer object that uses the given online_trainer object
to train a decision function.
......@@ -325,6 +352,9 @@ namespace dlib
- if (verbose_ == true) then
- this object will output status messages to standard out while
training is under way.
- if (use_cache_ == true) then
- this object will cache up to cache_size_ columns of the kernel
matrix during the training process.
!*/
const scalar_type get_min_learning_rate (
......@@ -364,12 +394,12 @@ namespace dlib
const batch_trainer<trainer_type> batch (
const trainer_type& trainer,
const typename trainer_type::scalar_type min_learning_rate = 0.1
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, false); }
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, false, false); }
/*!
requires
- min_learning_rate > 0
- trainer_type == some kind of online trainer object that creates decision_function
objects (e.g. svm_pegasos)
objects (e.g. svm_pegasos). replicate_settings() must also be defined for the type.
ensures
- returns a batch_trainer object that has been instantiated with the
given arguments.
......@@ -383,12 +413,12 @@ namespace dlib
const batch_trainer<trainer_type> verbose_batch (
const trainer_type& trainer,
const typename trainer_type::scalar_type min_learning_rate = 0.1
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, true); }
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, true, false); }
/*!
requires
- min_learning_rate > 0
- trainer_type == some kind of online trainer object that creates decision_function
objects (e.g. svm_pegasos)
objects (e.g. svm_pegasos). replicate_settings() must also be defined for the type.
ensures
- returns a batch_trainer object that has been instantiated with the
given arguments (and is verbose).
......@@ -396,6 +426,49 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
const batch_trainer<trainer_type> batch_cached (
const trainer_type& trainer,
const typename trainer_type::scalar_type min_learning_rate = 0.1,
long cache_size = 100
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, false, true, cache_size); }
/*!
requires
- min_learning_rate > 0
- cache_size > 0
- trainer_type == some kind of online trainer object that creates decision_function
objects (e.g. svm_pegasos). replicate_settings() must also be defined for the type.
ensures
- returns a batch_trainer object that has been instantiated with the
given arguments (uses a kernel cache).
!*/
// ----------------------------------------------------------------------------------------
template <
typename trainer_type
>
const batch_trainer<trainer_type> verbose_batch_cached (
const trainer_type& trainer,
const typename trainer_type::scalar_type min_learning_rate = 0.1,
long cache_size = 100
) { return batch_trainer<trainer_type>(trainer, min_learning_rate, true, true, cache_size); }
/*!
requires
- min_learning_rate > 0
- cache_size > 0
- trainer_type == some kind of online trainer object that creates decision_function
objects (e.g. svm_pegasos). replicate_settings() must also be defined for the type.
ensures
- returns a batch_trainer object that has been instantiated with the
given arguments (is verbose and uses a kernel cache).
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_PEGASoS_ABSTRACT_
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment