Commit 77036870 authored by Davis King

- Made the spec for the rank_features() function a little clearer. Also made
  the implementation do recursive feature elimination when the user tries to
  rank all the features.
- The report format that comes out of the rank_features() function is now
  also slightly different.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402541
parent a708705f
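
To make the change concrete: after this commit, asking rank_features() to rank every feature routes the call through the new recursive feature elimination path, and column 1 of the returned matrix reports the cumulative centroid separation achieved by features 0 through i rather than per-feature increments. The following usage sketch is not part of the commit; the kernel width, kcentroid tolerance, and toy data are made-up values for illustration:

    #include <dlib/svm.h>
    #include <iostream>
    #include <vector>

    int main()
    {
        using namespace dlib;
        typedef matrix<double,3,1> sample_type;               // 3 features per sample
        typedef radial_basis_kernel<sample_type> kernel_type;

        std::vector<sample_type> samples;
        std::vector<double> labels;

        // toy data: feature 0 carries the class signal, features 1 and 2 are noise
        sample_type s;
        for (int i = 0; i < 20; ++i)
        {
            s = (i < 10 ? -1.0 : 1.0), (i%5)*0.1, (i%3)*0.1;
            samples.push_back(s);
            labels.push_back(i < 10 ? -1 : +1);
        }

        // the kcentroid used to estimate each class's centroid in kernel feature space
        kcentroid<kernel_type> kc(kernel_type(0.5), 0.001);

        // ranking all 3 features now triggers the recursive feature elimination version
        matrix<double,0,2> ranks = rank_features(kc, samples, labels, 3);

        std::cout << ranks << std::endl;  // column 0: feature index, best first
                                          // column 1: separation using features 0 through i
    }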
@@ -22,10 +22,13 @@ namespace dlib
     matrix<typename kernel_type::scalar_type,0,2,typename kernel_type::mem_manager_type> rank_features_impl (
         const kcentroid<kernel_type>& kc,
         const sample_matrix_type& samples,
-        const label_matrix_type& labels,
-        const long num_features
+        const label_matrix_type& labels
     )
     {
+        /*
+            This function ranks features by doing recursive feature elimination
+        */
+
         typedef typename kernel_type::scalar_type scalar_type;
         typedef typename kernel_type::sample_type sample_type;
         typedef typename kernel_type::mem_manager_type mm;
@@ -36,24 +39,22 @@ namespace dlib
         // make sure requires clause is not broken
         DLIB_ASSERT(samples.nc() == 1 && labels.nc() == 1 && samples.size() == labels.size() &&
-                    samples.size() > 0 && num_features > 0,
+                    samples.size() > 0,
             "\tmatrix rank_features()"
             << "\n\t you have given invalid arguments to this function"
             << "\n\t samples.nc(): " << samples.nc()
             << "\n\t labels.nc(): " << labels.nc()
             << "\n\t samples.size(): " << samples.size()
             << "\n\t labels.size(): " << labels.size()
-            << "\n\t num_features: " << num_features
             );

 #ifdef ENABLE_ASSERTS
         for (long i = 0; i < samples.size(); ++i)
         {
-            DLIB_ASSERT(samples(i).nc() == 1 && num_features <= samples(i).nr() &&
+            DLIB_ASSERT(samples(i).nc() == 1 &&
                         samples(0).nr() == samples(i).nr(),
                 "\tmatrix rank_features()"
                 << "\n\t you have given invalid arguments to this function"
-                << "\n\t num_features: " << num_features
                 << "\n\t samples(i).nc(): " << samples(i).nc()
                 << "\n\t samples(i).nr(): " << samples(i).nr()
                 << "\n\t samples(0).nr(): " << samples(0).nr()
@@ -62,13 +63,36 @@ namespace dlib
 #endif

-        matrix<scalar_type,0,2,mm> results(num_features, 2);
+        matrix<scalar_type,0,2,mm> results(samples(0).nr(), 2);
         matrix<scalar_type,sample_matrix_type::type::NR,1,mm> mask(samples(0).nr());
-        set_all_elements(mask,0);
+        set_all_elements(mask,1);
+
+        // figure out what the separation is between the two centroids when all the features are
+        // present.
+        scalar_type first_separation;
+        {
+            kcentroid<kernel_type> c1(kc);
+            kcentroid<kernel_type> c2(kc);
+            // find the centers of each class
+            for (long s = 0; s < samples.size(); ++s)
+            {
+                if (labels(s) < 0)
+                {
+                    c1.train(samples(s));
+                }
+                else
+                {
+                    c2.train(samples(s));
+                }
+            }
+            first_separation = c1(c2);
+        }

         using namespace std;

-        for (long i = 0; i < results.nr(); ++i)
+        for (long i = results.nr()-1; i >= 0; --i)
         {
             long worst_feature_idx = 0;
             scalar_type worst_feature_score = -std::numeric_limits<scalar_type>::infinity();
@@ -77,14 +101,14 @@ namespace dlib
             for (long j = 0; j < mask.size(); ++j)
             {
                 // skip features we have already removed
-                if (mask(j) == 1)
+                if (mask(j) == 0)
                     continue;

                 kcentroid<kernel_type> c1(kc);
                 kcentroid<kernel_type> c2(kc);

                 // temporarily remove this feature from the working set of features
-                mask(j) = 1;
+                mask(j) = 0;

                 // find the centers of each class
                 for (long s = 0; s < samples.size(); ++s)
@@ -111,23 +135,140 @@ namespace dlib
                 }

                 // add this feature back to the working set of features
-                mask(j) = 0;
+                mask(j) = 1;
             }

             // now that we know what the next worst feature is record it
-            mask(worst_feature_idx) = 1;
+            mask(worst_feature_idx) = 0;
             results(i,0) = worst_feature_idx;
             results(i,1) = worst_feature_score;
         }

         // now normalize the results
-        set_colm(results,1) = colm(results,1)/max(colm(results,1));
-        for (long i = results.nr()-1; i > 0; --i)
+        const scalar_type max_separation = std::max(max(colm(results,1)), first_separation);
+        set_colm(results,1) = colm(results,1)/max_separation;
+        for (long r = 0; r < results.nr()-1; ++r)
         {
-            results(i,1) -= results(i-1,1);
+            results(r,1) = results(r+1,1);
         }
+        results(results.nr()-1,1) = first_separation/max_separation;

         return results;
     }

 // ----------------------------------------------------------------------------------------
+
+    template <
+        typename kernel_type,
+        typename sample_matrix_type,
+        typename label_matrix_type
+        >
+    matrix<typename kernel_type::scalar_type,0,2,typename kernel_type::mem_manager_type> rank_features (
+        const kcentroid<kernel_type>& kc,
+        const sample_matrix_type& samples,
+        const label_matrix_type& labels
+    )
+    {
+        return rank_features_impl(kc, vector_to_matrix(samples), vector_to_matrix(labels));
+    }
+
+// ----------------------------------------------------------------------------------------
+
+    template <
+        typename kernel_type,
+        typename sample_matrix_type,
+        typename label_matrix_type
+        >
+    matrix<typename kernel_type::scalar_type,0,2,typename kernel_type::mem_manager_type> rank_features_impl (
+        const kcentroid<kernel_type>& kc,
+        const sample_matrix_type& samples,
+        const label_matrix_type& labels,
+        const long num_features
+    )
+    {
+        /*
+            This function ranks features by doing recursive feature addition
+        */
+
+        typedef typename kernel_type::scalar_type scalar_type;
+        typedef typename kernel_type::sample_type sample_type;
+        typedef typename kernel_type::mem_manager_type mm;
+
+        // make sure requires clause is not broken
+        DLIB_ASSERT(is_binary_classification_problem(samples, labels) == true,
+            "\tmatrix rank_features()"
+            << "\n\t you have given invalid arguments to this function"
+            );
+        DLIB_ASSERT(0 < num_features && num_features <= samples(0).nr(),
+            "\tmatrix rank_features()"
+            << "\n\t you have given invalid arguments to this function"
+            << "\n\t num_features: " << num_features
+            << "\n\t samples(0).nr(): " << samples(0).nr()
+            );
+
+        matrix<scalar_type,0,2,mm> results(num_features, 2);
+        matrix<scalar_type,sample_matrix_type::type::NR,1,mm> mask(samples(0).nr());
+        set_all_elements(mask,0);
+
+        using namespace std;
+
+        for (long i = 0; i < results.nr(); ++i)
+        {
+            long best_feature_idx = 0;
+            scalar_type best_feature_score = -std::numeric_limits<scalar_type>::infinity();
+
+            // figure out which feature to add next
+            for (long j = 0; j < mask.size(); ++j)
+            {
+                // skip features we have already added
+                if (mask(j) == 1)
+                    continue;
+
+                kcentroid<kernel_type> c1(kc);
+                kcentroid<kernel_type> c2(kc);
+
+                // temporarily add this feature to the working set of features
+                mask(j) = 1;
+
+                // find the centers of each class
+                for (long s = 0; s < samples.size(); ++s)
+                {
+                    if (labels(s) < 0)
+                    {
+                        c1.train(pointwise_multiply(samples(s),mask));
+                    }
+                    else
+                    {
+                        c2.train(pointwise_multiply(samples(s),mask));
+                    }
+                }
+
+                // find the distance between the two centroids and use that
+                // as the score
+                const double score = c1(c2);
+
+                if (score > best_feature_score)
+                {
+                    best_feature_score = score;
+                    best_feature_idx = j;
+                }
+
+                // take this feature back out of the working set of features
+                mask(j) = 0;
+            }
+
+            // now that we know what the next best feature is record it
+            mask(best_feature_idx) = 1;
+            results(i,0) = best_feature_idx;
+            results(i,1) = best_feature_score;
+        }
+
+        // now normalize the results
+        set_colm(results,1) = colm(results,1)/max(colm(results,1));
+
+        return results;
+    }
@@ -145,7 +286,15 @@ namespace dlib
         const long num_features
     )
    {
-        return rank_features_impl(kc, vector_to_matrix(samples), vector_to_matrix(labels), num_features);
+        if (vector_to_matrix(samples).nr() > 0 && num_features == vector_to_matrix(samples)(0).nr())
+        {
+            // if we are going to rank them all then might as well do the recursive feature elimination version
+            return rank_features_impl(kc, vector_to_matrix(samples), vector_to_matrix(labels));
+        }
+        else
+        {
+            return rank_features_impl(kc, vector_to_matrix(samples), vector_to_matrix(labels), num_features);
+        }
     }

 // ----------------------------------------------------------------------------------------
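The heart of the new code path above is the elimination loop: start with every feature enabled, tentatively disable each remaining feature in turn, score the class-centroid separation that survives, and permanently drop the feature whose removal costs the least. The sketch below illustrates just that loop, substituting a plain Euclidean distance between class means for the kernel-space kcentroid distance; it is an illustration only, with made-up data, not dlib code:

    #include <cmath>
    #include <iostream>
    #include <limits>
    #include <vector>

    // separation between the two class means, with disabled features zeroed out
    double separation(const std::vector<std::vector<double> >& x,
                      const std::vector<int>& y,
                      const std::vector<int>& mask)
    {
        const std::size_t dims = x[0].size();
        std::vector<double> c1(dims, 0), c2(dims, 0);
        int n1 = 0, n2 = 0;
        for (std::size_t s = 0; s < x.size(); ++s)
        {
            for (std::size_t j = 0; j < dims; ++j)
            {
                const double v = x[s][j]*mask[j];  // mask out disabled features
                if (y[s] < 0) c1[j] += v; else c2[j] += v;
            }
            if (y[s] < 0) ++n1; else ++n2;
        }
        double d = 0;
        for (std::size_t j = 0; j < dims; ++j)
        {
            const double diff = c1[j]/n1 - c2[j]/n2;
            d += diff*diff;
        }
        return std::sqrt(d);
    }

    int main()
    {
        // feature 0 separates the classes, feature 1 is noise
        std::vector<std::vector<double> > x = {{-1.0, 0.3}, {-1.2, 0.1}, {1.0, 0.2}, {1.1, 0.4}};
        std::vector<int> y = {-1, -1, +1, +1};

        std::vector<int> mask(2, 1);             // all features start enabled
        for (int round = 0; round < 2; ++round)  // eliminate one feature per round
        {
            int worst = 0;
            double worst_score = -std::numeric_limits<double>::infinity();
            for (std::size_t j = 0; j < mask.size(); ++j)
            {
                if (mask[j] == 0)                // already eliminated
                    continue;
                mask[j] = 0;                     // tentatively remove feature j
                const double score = separation(x, y, mask);
                mask[j] = 1;                     // put it back
                // the feature whose removal leaves the most separation is the
                // least useful one, so it goes next
                if (score > worst_score) { worst_score = score; worst = (int)j; }
            }
            mask[worst] = 0;                     // permanently drop it
            std::cout << "eliminated feature " << worst << "\n";
        }
    }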
@@ -6,6 +6,7 @@
 #include <vector>
 #include <limits>
 #include "svm_abstract.h"
+#include "kcentroid_abstract.h"
 #include "../is_kind.h"
@@ -23,35 +24,23 @@ namespace dlib
         const kcentroid<kernel_type>& kc,
         const sample_matrix_type& samples,
         const label_matrix_type& labels,
-        const long num_features
+        const long num_features = samples(0).nr()
     );
     /*!
         requires
-            - vector_to_matrix(samples) == a valid matrix object
-            - vector_to_matrix(samples(0)) == a valid matrix object
-            - vector_to_matrix(labels) == a valid matrix object
-              (i.e. the 3 above things must either be dlib::matrix objects or be
-              convertable to them via the vector_to_matrix() function)
-            - samples.nc() == 1 && labels.nc() == 1
-              (i.e. samples and labels must be column vectors)
-            - samples.size() == labels.size()
-            - samples.size() > 0
-            - for all i < samples.size():
-                - 0 < num_features <= samples(i).nr()
-                - samples(i).nc() == 1
-            - i.e. samples must contain column vectors of equal length
-              and num_features must be less than the size of these column vectors
+            - is_binary_classification_problem(samples, labels) == true
             - kc.train(samples(0)) must be a valid expression. This means that
               kc must use a kernel type that is capable of operating on the
               contents of the samples matrix
+            - 0 < num_features <= samples(0).nr()
        ensures
            - Let Class1 denote the centroid of all the samples with labels that are < 0
            - Let Class2 denote the centroid of all the samples with labels that are > 0
-            - finds a ranking of the top num_features best features. This function
-              does this by computing the distance between the centroid of the Class1
+            - finds a ranking of the features where the best features come first. This
+              function does this by computing the distance between the centroid of the Class1
              samples and the Class2 samples in kernel defined feature space.
              Good features are then ones that result in the biggest separation between
-              the two centroids of Class1 and Class2
+              the two centroids of Class1 and Class2.
            - Uses the kc object to compute the centroids of the two classes
            - returns a ranking matrix R where:
                - R.nr() == num_features
@@ -60,10 +49,8 @@ namespace dlib
                  (e.g. samples(n)(R(0,0)) is the best feature from sample(n) and
                  samples(n)(R(1,0)) is the second best, samples(n)(R(2,0)) the
                  third best and so on)
-                - R(i,1) == a number that indicates how much the feature R(i,0) contributes
-                  to the separation of the Class1 and Class2 centroids when it
-                  is added into the feature set defined by R(0,0), R(1,0), R(2,0), up to
-                  R(i-1,0).
+                - R(i,1) == a number that indicates how much separation exists between
+                  the two centroids when features 0 through i are used.
    !*/

 // ----------------------------------------------------------------------------------------
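The report-format change is worth a short walk-through, because the elimination loop fills results from the bottom up (i runs from results.nr()-1 down to 0): when row i is written, results(i,1) holds the separation measured after feature R(i,0) was removed, i.e. the separation achieved by the surviving features R(0,0) through R(i-1,0). The final shift results(r,1) = results(r+1,1), together with writing first_separation into the last row, realigns column 1 so that R(i,1) is the separation when features 0 through i are used, matching the updated spec. A hypothetical three-feature trace (the numbers are made up):

    // Suppose first_separation (all 3 features) = 4.0 and the rounds measure:
    //   round 1 (i = 2): dropping feature 2 leaves separation 3.9 -> row 2 = (2, 3.9)
    //   round 2 (i = 1): dropping feature 1 leaves separation 3.0 -> row 1 = (1, 3.0)
    //   round 3 (i = 0): dropping feature 0 leaves separation 0.0 -> row 0 = (0, 0.0)
    // Before the shift, row i's score uses features 0..i-1.  After dividing by
    // max_separation = max(3.9, 4.0) = 4.0 and shifting column 1 up one row:
    //   row 0 = (0, 0.75 )  separation using feature  {R(0,0)} = {0}
    //   row 1 = (1, 0.975)  separation using features {0, 1}
    //   row 2 = (2, 1.0  )  separation using all features (first_separation)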