Commit 734252aa authored by Davis King's avatar Davis King

Made the rbf_network's template argument be a kernel type

instead of a sample type.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402401
parent 4a19c35a
......@@ -16,7 +16,7 @@ namespace dlib
// ------------------------------------------------------------------------------
template <
typename sample_type_
typename K
>
class rbf_network_trainer
{
......@@ -27,36 +27,29 @@ namespace dlib
!*/
public:
typedef radial_basis_kernel<sample_type_> kernel_type;
typedef sample_type_ sample_type;
typedef typename sample_type::type scalar_type;
typedef typename sample_type::mem_manager_type mem_manager_type;
typedef decision_function<kernel_type> trained_function_type;
typedef K kernel_type;
typedef typename kernel_type::scalar_type scalar_type;
typedef typename kernel_type::sample_type sample_type;
typedef typename kernel_type::mem_manager_type mem_manager_type;
typedef decision_function<kernel_type> trained_function_type;
rbf_network_trainer (
) :
gamma(0.1),
tolerance(0.01)
tolerance(0.1)
{
}
void set_gamma (
scalar_type gamma_
void set_kernel (
const kernel_type& k
)
{
// make sure requires clause is not broken
DLIB_ASSERT(gamma_ > 0,
"\tvoid rbf_network_trainer::set_gamma(gamma_)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t gamma: " << gamma_
);
gamma = gamma_;
kernel = k;
}
const scalar_type get_gamma (
const kernel_type& get_kernel (
) const
{
return gamma;
{
return kernel;
}
void set_tolerance (
......@@ -88,7 +81,7 @@ namespace dlib
rbf_network_trainer& item
)
{
exchange(gamma, item.gamma);
exchange(kernel, item.kernel);
exchange(tolerance, item.tolerance);
}
......@@ -118,7 +111,6 @@ namespace dlib
);
// first run all the sampes through a kcentroid object to find the rbf centers
const kernel_type kernel(gamma);
kcentroid<kernel_type> kc(kernel,tolerance);
for (long i = 0; i < x.size(); ++i)
{
......@@ -154,7 +146,7 @@ namespace dlib
}
scalar_type gamma;
kernel_type kernel;
scalar_type tolerance;
}; // end of class rbf_network_trainer
......
......@@ -14,17 +14,19 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename sample_type_
typename K
>
class rbf_network_trainer
{
/*!
REQUIREMENTS ON sample_type_
is a dlib::matrix type
REQUIREMENTS ON K
is a kernel function object as defined in dlib/svm/kernel_abstract.h
(since this is supposed to be a RBF network it is probably reasonable
to use some sort of radial basis kernel)
INITIAL VALUE
- get_gamma() == 0.1
- get_tolerance() == 0.01
- get_tolerance() == 0.1
WHAT THIS OBJECT REPRESENTS
This object implements a trainer for an radial basis function network.
......@@ -34,11 +36,11 @@ namespace dlib
about RBF networks.
!*/
public:
typedef radial_basis_kernel<sample_type_> kernel_type;
typedef sample_type_ sample_type;
typedef typename sample_type::type scalar_type;
typedef typename sample_type::mem_manager_type mem_manager_type;
typedef decision_function<kernel_type> trained_function_type;
typedef K kernel_type;
typedef typename kernel_type::scalar_type scalar_type;
typedef typename kernel_type::sample_type sample_type;
typedef typename kernel_type::mem_manager_type mem_manager_type;
typedef decision_function<kernel_type> trained_function_type;
rbf_network_trainer (
);
......@@ -47,22 +49,19 @@ namespace dlib
- this object is properly initialized
!*/
void set_gamma (
scalar_type gamma
void set_kernel (
const kernel_type& k
);
/*!
requires
- gamma > 0
ensures
- #get_gamma() == gamma
- #get_kernel() == k
!*/
const scalar_type get_gamma (
) const
const kernel_type& get_kernel (
) const;
/*!
ensures
- returns the gamma argument used in the radial_basis_kernel used
to represent each node in an RBF network.
- returns a copy of the kernel function in use by this object
!*/
void set_tolerance (
......@@ -102,12 +101,7 @@ namespace dlib
(i.e. x and y are both column vectors of the same length)
ensures
- trains a RBF network given the training samples in x and
labels in y.
- returns a decision function F with the following properties:
- if (new_x is a sample predicted have +1 label) then
- F(new_x) >= 0
- else
- F(new_x) < 0
labels in y and returns the resulting decision_function
throws
- std::bad_alloc
!*/
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment