Commit c8d9f20b authored by Davis King

Changed the rbf_network_trainer to use the linearly_independent_subset_finder to find centers.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402413
parent ec316ccb
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "../matrix.h" #include "../matrix.h"
#include "rbf_network_abstract.h" #include "rbf_network_abstract.h"
#include "kernel.h" #include "kernel.h"
#include "kcentroid.h" #include "linearly_independent_subset_finder.h"
#include "function.h" #include "function.h"
#include "../algs.h" #include "../algs.h"
...@@ -23,7 +23,8 @@ namespace dlib ...@@ -23,7 +23,8 @@ namespace dlib
/*! /*!
This is an implementation of an RBF network trainer that follows This is an implementation of an RBF network trainer that follows
the directions right off Wikipedia basically. So nothing the directions right off Wikipedia basically. So nothing
particularly fancy. particularly fancy. Although the way the centers are selected
is somewhat unique.
!*/ !*/
public: public:
...@@ -35,7 +36,7 @@ namespace dlib ...@@ -35,7 +36,7 @@ namespace dlib
rbf_network_trainer ( rbf_network_trainer (
) : ) :
tolerance(0.1) num_centers(10)
{ {
} }
...@@ -52,17 +53,17 @@ namespace dlib ...@@ -52,17 +53,17 @@ namespace dlib
return kernel; return kernel;
} }
void set_tolerance ( void set_num_centers (
const scalar_type& tol const unsigned long num
) )
{ {
tolerance = tol; num_centers = num;
} }
const scalar_type& get_tolerance ( const unsigned long get_num_centers (
) const ) const
{ {
return tolerance; return num_centers;
} }
template < template <
...@@ -82,7 +83,7 @@ namespace dlib ...@@ -82,7 +83,7 @@ namespace dlib
) )
{ {
exchange(kernel, item.kernel); exchange(kernel, item.kernel);
exchange(tolerance, item.tolerance); exchange(num_centers, item.num_centers);
} }
private: private:
...@@ -99,6 +100,7 @@ namespace dlib ...@@ -99,6 +100,7 @@ namespace dlib
) const ) const
{ {
typedef typename decision_function<kernel_type>::scalar_vector_type scalar_vector_type; typedef typename decision_function<kernel_type>::scalar_vector_type scalar_vector_type;
typedef typename decision_function<kernel_type>::sample_vector_type sample_vector_type;
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT(x.nr() > 1 && x.nr() == y.nr() && x.nc() == 1 && y.nc() == 1, DLIB_ASSERT(x.nr() > 1 && x.nr() == y.nr() && x.nc() == 1 && y.nc() == 1,
...@@ -110,18 +112,15 @@ namespace dlib ...@@ -110,18 +112,15 @@ namespace dlib
<< "\n\t y.nc(): " << y.nc() << "\n\t y.nc(): " << y.nc()
); );
// first run all the samples through a kcentroid object to find the rbf centers // use the linearly_independent_subset_finder object to select the centers. So here
kcentroid<kernel_type> kc(kernel,tolerance); // we show it all the data samples so it can find the best centers.
linearly_independent_subset_finder<kernel_type> lisf(kernel, num_centers);
for (long i = 0; i < x.size(); ++i) for (long i = 0; i < x.size(); ++i)
{ {
kc.train(x(i)); lisf.add(x(i));
} }
// now we have a trained kcentroid so lets just extract its results. Note that const long num_centers = lisf.dictionary_size();
// all we want out of the kcentroid is really just the set of support vectors
// it contains so that we can use them as the RBF centers.
distance_function<kernel_type> df(kc.get_distance_function());
const long num_centers = df.support_vectors.nr();
// fill the K matrix with the output of the kernel for all the center and sample point pairs // fill the K matrix with the output of the kernel for all the center and sample point pairs
matrix<scalar_type,0,0,mem_manager_type> K(x.nr(), num_centers+1); matrix<scalar_type,0,0,mem_manager_type> K(x.nr(), num_centers+1);
...@@ -129,7 +128,7 @@ namespace dlib ...@@ -129,7 +128,7 @@ namespace dlib
{ {
for (long c = 0; c < num_centers; ++c) for (long c = 0; c < num_centers; ++c)
{ {
K(r,c) = kernel(x(r), df.support_vectors(c)); K(r,c) = kernel(x(r), lisf[c]);
} }
// This last column of the K matrix takes care of the bias term // This last column of the K matrix takes care of the bias term
K(r,num_centers) = 1; K(r,num_centers) = 1;
...@@ -142,12 +141,12 @@ namespace dlib ...@@ -142,12 +141,12 @@ namespace dlib
return decision_function<kernel_type> (remove_row(weights,num_centers), return decision_function<kernel_type> (remove_row(weights,num_centers),
-weights(num_centers), -weights(num_centers),
kernel, kernel,
df.support_vectors); lisf.get_dictionary());
} }
kernel_type kernel; kernel_type kernel;
scalar_type tolerance; unsigned long num_centers;
}; // end of class rbf_network_trainer }; // end of class rbf_network_trainer
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#undef DLIB_RBf_NETWORK_ABSTRACT_ #undef DLIB_RBf_NETWORK_ABSTRACT_
#ifdef DLIB_RBf_NETWORK_ABSTRACT_ #ifdef DLIB_RBf_NETWORK_ABSTRACT_
#include "../matrix/matrix_abstract.h"
#include "../algs.h" #include "../algs.h"
#include "function_abstract.h" #include "function_abstract.h"
#include "kernel_abstract.h" #include "kernel_abstract.h"
...@@ -25,16 +24,16 @@ namespace dlib ...@@ -25,16 +24,16 @@ namespace dlib
to use some sort of radial basis kernel) to use some sort of radial basis kernel)
INITIAL VALUE INITIAL VALUE
- get_gamma() == 0.1 - get_num_centers() == 10
- get_tolerance() == 0.1
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
This object implements a trainer for an radial basis function network. This object implements a trainer for a radial basis function network.
The implementation of this algorithm follows the normal RBF training The implementation of this algorithm follows the normal RBF training
process. For more details see the code or the Wikipedia article process. For more details see the code or the Wikipedia article
about RBF networks. about RBF networks.
!*/ !*/
public: public:
typedef K kernel_type; typedef K kernel_type;
typedef typename kernel_type::scalar_type scalar_type; typedef typename kernel_type::scalar_type scalar_type;
...@@ -64,22 +63,20 @@ namespace dlib ...@@ -64,22 +63,20 @@ namespace dlib
- returns a copy of the kernel function in use by this object - returns a copy of the kernel function in use by this object
!*/ !*/
void set_tolerance ( void set_num_centers (
const scalar_type& tol const unsigned long num_centers
); );
/*! /*!
ensures ensures
- #get_tolerance() == tol - #get_num_centers() == num_centers
!*/ !*/
const scalar_type& get_tolerance ( const unsigned long get_num_centers (
) const; ) const;
/*! /*!
ensures ensures
- returns the tolerance parameter. This parameter controls how many - returns the number of centers (a.k.a. support_vectors in the
RBF centers (a.k.a. support_vectors in the trained decision_function) trained decision_function) you will get when you train this object on data.
you get when you call the train function. A smaller tolerance
results in more centers while a bigger number results in fewer.
!*/ !*/
template < template <
...@@ -118,10 +115,10 @@ namespace dlib ...@@ -118,10 +115,10 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename sample_type> template <typename K>
void swap ( void swap (
rbf_network_trainer<sample_type>& a, rbf_network_trainer<K>& a,
rbf_network_trainer<sample_type>& b rbf_network_trainer<K>& b
) { a.swap(b); } ) { a.swap(b); }
/*! /*!
provides a global swap provides a global swap
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment