Commit 734252aa authored by Davis King's avatar Davis King

Made the rbf_network's template argument be a kernel type

instead of a sample type.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402401
parent 4a19c35a
...@@ -16,7 +16,7 @@ namespace dlib ...@@ -16,7 +16,7 @@ namespace dlib
// ------------------------------------------------------------------------------ // ------------------------------------------------------------------------------
template < template <
typename sample_type_ typename K
> >
class rbf_network_trainer class rbf_network_trainer
{ {
...@@ -27,36 +27,29 @@ namespace dlib ...@@ -27,36 +27,29 @@ namespace dlib
!*/ !*/
public: public:
typedef radial_basis_kernel<sample_type_> kernel_type; typedef K kernel_type;
typedef sample_type_ sample_type; typedef typename kernel_type::scalar_type scalar_type;
typedef typename sample_type::type scalar_type; typedef typename kernel_type::sample_type sample_type;
typedef typename sample_type::mem_manager_type mem_manager_type; typedef typename kernel_type::mem_manager_type mem_manager_type;
typedef decision_function<kernel_type> trained_function_type; typedef decision_function<kernel_type> trained_function_type;
rbf_network_trainer ( rbf_network_trainer (
) : ) :
gamma(0.1), tolerance(0.1)
tolerance(0.01)
{ {
} }
void set_gamma ( void set_kernel (
scalar_type gamma_ const kernel_type& k
) )
{ {
// make sure requires clause is not broken kernel = k;
DLIB_ASSERT(gamma_ > 0,
"\tvoid rbf_network_trainer::set_gamma(gamma_)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t gamma: " << gamma_
);
gamma = gamma_;
} }
const scalar_type get_gamma ( const kernel_type& get_kernel (
) const ) const
{ {
return gamma; return kernel;
} }
void set_tolerance ( void set_tolerance (
...@@ -88,7 +81,7 @@ namespace dlib ...@@ -88,7 +81,7 @@ namespace dlib
rbf_network_trainer& item rbf_network_trainer& item
) )
{ {
exchange(gamma, item.gamma); exchange(kernel, item.kernel);
exchange(tolerance, item.tolerance); exchange(tolerance, item.tolerance);
} }
...@@ -118,7 +111,6 @@ namespace dlib ...@@ -118,7 +111,6 @@ namespace dlib
); );
// first run all the sampes through a kcentroid object to find the rbf centers // first run all the sampes through a kcentroid object to find the rbf centers
const kernel_type kernel(gamma);
kcentroid<kernel_type> kc(kernel,tolerance); kcentroid<kernel_type> kc(kernel,tolerance);
for (long i = 0; i < x.size(); ++i) for (long i = 0; i < x.size(); ++i)
{ {
...@@ -154,7 +146,7 @@ namespace dlib ...@@ -154,7 +146,7 @@ namespace dlib
} }
scalar_type gamma; kernel_type kernel;
scalar_type tolerance; scalar_type tolerance;
}; // end of class rbf_network_trainer }; // end of class rbf_network_trainer
......
...@@ -14,17 +14,19 @@ namespace dlib ...@@ -14,17 +14,19 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
typename sample_type_ typename K
> >
class rbf_network_trainer class rbf_network_trainer
{ {
/*! /*!
REQUIREMENTS ON sample_type_ REQUIREMENTS ON K
is a dlib::matrix type is a kernel function object as defined in dlib/svm/kernel_abstract.h
(since this is supposed to be a RBF network it is probably reasonable
to use some sort of radial basis kernel)
INITIAL VALUE INITIAL VALUE
- get_gamma() == 0.1 - get_gamma() == 0.1
- get_tolerance() == 0.01 - get_tolerance() == 0.1
WHAT THIS OBJECT REPRESENTS WHAT THIS OBJECT REPRESENTS
This object implements a trainer for an radial basis function network. This object implements a trainer for an radial basis function network.
...@@ -34,10 +36,10 @@ namespace dlib ...@@ -34,10 +36,10 @@ namespace dlib
about RBF networks. about RBF networks.
!*/ !*/
public: public:
typedef radial_basis_kernel<sample_type_> kernel_type; typedef K kernel_type;
typedef sample_type_ sample_type; typedef typename kernel_type::scalar_type scalar_type;
typedef typename sample_type::type scalar_type; typedef typename kernel_type::sample_type sample_type;
typedef typename sample_type::mem_manager_type mem_manager_type; typedef typename kernel_type::mem_manager_type mem_manager_type;
typedef decision_function<kernel_type> trained_function_type; typedef decision_function<kernel_type> trained_function_type;
rbf_network_trainer ( rbf_network_trainer (
...@@ -47,22 +49,19 @@ namespace dlib ...@@ -47,22 +49,19 @@ namespace dlib
- this object is properly initialized - this object is properly initialized
!*/ !*/
void set_gamma ( void set_kernel (
scalar_type gamma const kernel_type& k
); );
/*! /*!
requires
- gamma > 0
ensures ensures
- #get_gamma() == gamma - #get_kernel() == k
!*/ !*/
const scalar_type get_gamma ( const kernel_type& get_kernel (
) const ) const;
/*! /*!
ensures ensures
- returns the gamma argument used in the radial_basis_kernel used - returns a copy of the kernel function in use by this object
to represent each node in an RBF network.
!*/ !*/
void set_tolerance ( void set_tolerance (
...@@ -102,12 +101,7 @@ namespace dlib ...@@ -102,12 +101,7 @@ namespace dlib
(i.e. x and y are both column vectors of the same length) (i.e. x and y are both column vectors of the same length)
ensures ensures
- trains a RBF network given the training samples in x and - trains a RBF network given the training samples in x and
labels in y. labels in y and returns the resulting decision_function
- returns a decision function F with the following properties:
- if (new_x is a sample predicted have +1 label) then
- F(new_x) >= 0
- else
- F(new_x) < 0
throws throws
- std::bad_alloc - std::bad_alloc
!*/ !*/
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment