Commit a01f495a authored by Davis King's avatar Davis King

Added wrapper rank_features() functions to help the user

pick reasonable default inputs to this function.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403234
parent 0a7791df
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
#include "feature_ranking_abstract.h" #include "feature_ranking_abstract.h"
#include "kcentroid.h" #include "kcentroid.h"
#include "../optimization.h"
#include <iostream>
namespace dlib namespace dlib
{ {
...@@ -274,6 +276,161 @@ namespace dlib ...@@ -274,6 +276,161 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
namespace rank_features_helpers
{
template <
typename K,
typename sample_matrix_type,
typename label_matrix_type
>
typename K::scalar_type centroid_gap (
const kcentroid<K>& kc,
const sample_matrix_type& samples,
const label_matrix_type& labels
)
{
kcentroid<K> kc1(kc);
kcentroid<K> kc2(kc);
// toss all the samples into our kcentroids
for (long i = 0; i < samples.size(); ++i)
{
if (labels(i) > 0)
kc1.train(samples(i));
else
kc2.train(samples(i));
}
// now return the separation between the mean of these two centroids
return kc1(kc2);
}
template <
typename sample_matrix_type,
typename label_matrix_type
>
class test
{
typedef typename sample_matrix_type::type sample_type;
typedef typename sample_type::type scalar_type;
typedef typename sample_type::mem_manager_type mem_manager_type;
public:
test (
const sample_matrix_type& samples_,
const label_matrix_type& labels_,
unsigned long num_sv_,
bool verbose_
) : samples(samples_), labels(labels_), num_sv(num_sv_), verbose(verbose_)
{
}
double operator() (
const double gamma
) const
{
using namespace std;
if (verbose)
{
cout << "\rChecking goodness of gamma = " << gamma << ". " << flush;
}
typedef radial_basis_kernel<sample_type> kernel_type;
// make a kcentroid and find out what the gap is at the current gamma
kcentroid<kernel_type> kc(kernel_type(gamma), 0.0001, num_sv);
scalar_type temp = centroid_gap(kc, samples, labels);
if (verbose)
{
cout << "Goodness = " << temp << " " << flush;
}
return temp;
}
const sample_matrix_type& samples;
const label_matrix_type& labels;
unsigned long num_sv;
bool verbose;
};
template <
typename sample_matrix_type,
typename label_matrix_type
>
matrix<double,0,2> rank_features_rbf_impl (
const sample_matrix_type& samples,
const label_matrix_type& labels,
unsigned long num_sv,
bool verbose
)
{
typedef typename sample_matrix_type::type sample_type;
using namespace std;
if (verbose)
{
cout << endl;
}
// The first thing to do is to estimate what a good gamma is. Since the feature ranking
// works by ranking features by how much they separate the centroids of the two classes
// we will pick the gamma to use by finding the one that best separates the two classes.
test<sample_matrix_type, label_matrix_type> funct(samples, labels, num_sv, verbose);
double best_gamma = find_max_single_variable(funct, 1.0, 1e-8, 1000, 1e-3, 50).first;
typedef radial_basis_kernel<sample_type> kernel_type;
if (verbose)
{
cout << "\nNow doing feature ranking using a gamma of " << best_gamma << endl;
}
// now just call the normal rank_features function and return whatever it says
kcentroid<kernel_type> kc(kernel_type(best_gamma), 0.0001, num_sv);
return matrix_cast<double>(rank_features(kc, samples, labels));
}
}
// ----------------------------------------------------------------------------------------
template <
typename sample_matrix_type,
typename label_matrix_type
>
matrix<double,0,2> rank_features_rbf (
const sample_matrix_type& samples,
const label_matrix_type& labels,
unsigned long num_sv = 40
)
{
return rank_features_helpers::rank_features_rbf_impl(vector_to_matrix(samples),
vector_to_matrix(labels),
num_sv,
false);
}
// ----------------------------------------------------------------------------------------
template <
typename sample_matrix_type,
typename label_matrix_type
>
matrix<double,0,2> verbose_rank_features_rbf (
const sample_matrix_type& samples,
const label_matrix_type& labels,
unsigned long num_sv = 40
)
{
return rank_features_helpers::rank_features_rbf_impl(vector_to_matrix(samples),
vector_to_matrix(labels),
num_sv,
true);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
...@@ -55,6 +55,54 @@ namespace dlib ...@@ -55,6 +55,54 @@ namespace dlib
the two centroids when features 0 through i are used. the two centroids when features 0 through i are used.
!*/ !*/
// ----------------------------------------------------------------------------------------
template <
typename sample_matrix_type,
typename label_matrix_type
>
matrix<double,0,2> rank_features_rbf (
const sample_matrix_type& samples,
const label_matrix_type& labels,
unsigned long num_sv = 40
);
/*!
requires
- num_sv > 0
- is_binary_classification_problem(samples, labels) == true
ensures
- This function just calls the above rank_features() function but uses the
radial_basis_kernel and automatically picks a gamma parameter for you.
It also sets the kcentroid up to use num_sv dictionary vectors. Finally, it
tells rank_features() to rank all the features.
- The return value from this function is the matrix returned by rank_features()
!*/
// ----------------------------------------------------------------------------------------
template <
typename sample_matrix_type,
typename label_matrix_type
>
matrix<double,0,2> verbose_rank_features_rbf (
const sample_matrix_type& samples,
const label_matrix_type& labels,
unsigned long num_sv = 40
);
/*!
requires
- num_sv > 0
- is_binary_classification_problem(samples, labels) == true
ensures
- This function just calls the above rank_features() function but uses the
radial_basis_kernel and automatically picks a gamma parameter for you.
It also sets the kcentroid up to use num_sv dictionary vectors. Finally, it
tells rank_features() to rank all the features.
- The return value from this function is the matrix returned by rank_features()
- This function is verbose in the sense that it will print status messages to
standard out during its processing.
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment