Commit fd31f948 authored by Davis King

Added a function for performing supervised basis selection.

--HG--
extra : convert_revision : svn:fdd8eb12-d10e-0410-9acb-85c331704f74/trunk@3888
parent 65625185
dlib/svm.h
@@ -23,6 +23,7 @@
#include "svm/svm_c_ekm_trainer.h"
#include "svm/simplify_linear_decision_function.h"
#include "svm/krr_trainer.h"
#include "svm/sort_basis_vectors.h"
#endif // DLIB_SVm_HEADER
dlib/svm/sort_basis_vectors.h
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_SORT_BASIS_VECTORs_H__
#define DLIB_SORT_BASIS_VECTORs_H__
#include <vector>
#include "sort_basis_vectors_abstract.h"
#include "../matrix.h"
#include "../statistics.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
namespace bs_impl
{
template <typename EXP>
typename EXP::matrix_type invert (
const matrix_exp<EXP>& m
)
{
eigenvalue_decomposition<EXP> eig(make_symmetric(m));
typedef typename EXP::type scalar_type;
typedef typename EXP::mem_manager_type mm_type;
matrix<scalar_type,0,1,mm_type> vals = eig.get_real_eigenvalues();
const scalar_type max_eig = max(abs(vals));
const scalar_type thresh = max_eig*std::sqrt(std::numeric_limits<scalar_type>::epsilon());
// Since m might be singular or almost singular we need to do something about
// any very small eigenvalues. So here we set the smallest eigenvalues to
// be equal to a large value to make the inversion stable. We can't just set
// them to zero like in a normal pseudo-inverse since we want the resulting
// inverse matrix to be full rank.
for (long i = 0; i < vals.size(); ++i)
{
if (std::abs(vals(i)) < thresh)
vals(i) = max_eig;
}
// Build the inverse matrix. This is basically a pseudo-inverse.
return make_symmetric(eig.get_pseudo_v()*diagm(reciprocal(vals))*trans(eig.get_pseudo_v()));
}
}
// ----------------------------------------------------------------------------------------
template <
typename kernel_type,
typename vect1_type,
typename vect2_type,
typename vect3_type
>
const std::vector<typename kernel_type::sample_type> sort_basis_vectors (
const kernel_type& kern,
const vect1_type& samples,
const vect2_type& labels,
const vect3_type& basis,
double eps = 0.99
)
{
DLIB_ASSERT(is_binary_classification_problem(samples, labels) &&
0 < eps && eps <= 1,
"\t void sort_basis_vectors()"
<< "\n\t Invalid arguments were given to this function."
<< "\n\t is_binary_classification_problem(samples, labels): " << is_binary_classification_problem(samples, labels)
<< "\n\t eps: " << eps
);
typedef typename kernel_type::sample_type sample_type;
typedef typename kernel_type::scalar_type scalar_type;
typedef typename kernel_type::mem_manager_type mm_type;
typedef matrix<scalar_type,0,1,mm_type> col_matrix;
typedef matrix<scalar_type,0,0,mm_type> gen_matrix;
col_matrix c1_mean, c2_mean, temp, delta;
col_matrix weights;
running_covariance<gen_matrix> cov;
// compute the covariance matrix and the means of the two classes.
for (unsigned long i = 0; i < samples.size(); ++i)
{
temp = kernel_matrix(kern, basis, samples[i]);
cov.add(temp);
if (labels[i] > 0)
c1_mean += temp;
else
c2_mean += temp;
}
c1_mean /= sum(vector_to_matrix(labels) > 0);
c2_mean /= sum(vector_to_matrix(labels) < 0);
delta = c1_mean - c2_mean;
gen_matrix cov_inv = bs_impl::invert(cov.covariance());
matrix<long,0,1,mm_type> total_perm = trans(range(0, delta.size()-1));
matrix<long,0,1,mm_type> perm = total_perm;
std::vector<std::pair<scalar_type,long> > sorted_feats(delta.size());
long best_size = delta.size();
long misses = 0;
// Initialize to the identity permutation so best_total_perm is always valid,
// even if no iteration ever improves on best_size.
matrix<long,0,1,mm_type> best_total_perm = total_perm;
// Now we basically find Fisher's linear discriminant over and over, each
// time sorting the features so that the most important ones pile up together.
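// Note that trans(chol(cov_inv))*delta is the Fisher direction inv(cov)*delta
// expressed in whitened coordinates.  The squared entries of this weights
// vector sum to trans(delta)*inv(cov)*delta, so each weights(i)^2 can be read
// as that feature's share of the total discriminating power.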
weights = trans(chol(cov_inv))*delta;
while (true)
{
for (unsigned long i = 0; i < sorted_feats.size(); ++i)
sorted_feats[i] = std::make_pair(std::abs(weights(i)), i);
std::sort(sorted_feats.begin(), sorted_feats.end());
// make a permutation vector according to the sorting
for (long i = 0; i < perm.size(); ++i)
perm(i) = sorted_feats[i].second;
// Apply the permutation. Doing this gives the same result as permuting all the
// features and then recomputing the delta and cov_inv from scratch.
cov_inv = subm(cov_inv,perm,perm);
delta = rowm(delta,perm);
// Record all the permutations we have done so we will know how the final
// weights match up with the original basis vectors when we are done.
total_perm = rowm(total_perm, perm);
// compute new Fisher weights for sorted features.
weights = trans(chol(cov_inv))*delta;
// Measure how many features it takes to account for eps% of the weights vector.
const scalar_type total_weight = length_squared(weights);
scalar_type weight_accum = 0;
long size = 0;
// figure out how to get eps% of the weights
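// The features were sorted into ascending order of |weight| above, so we
// accumulate from the end of the vector where the most important ones live.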
for (long i = weights.size()-1; i >= 0; --i)
{
++size;
weight_accum += weights(i)*weights(i);
if (weight_accum/total_weight > eps)
break;
}
// loop until the best_size stops dropping
if (size < best_size)
{
misses = 0;
best_size = size;
best_total_perm = total_perm;
}
else
{
++misses;
// Give up once we have had 10 rounds where we didn't find a weights vector with
// a smaller concentration of good features.
if (misses >= 10)
break;
}
}
// make sure best_size isn't zero
if (best_size == 0)
best_size = 1;
std::vector<typename kernel_type::sample_type> sorted_basis;
// permute the basis so that it matches up with the contents of the best weights
sorted_basis.resize(best_size);
for (unsigned long i = 0; i < sorted_basis.size(); ++i)
{
// Note that we load sorted_basis backwards so that the most important
// basis elements come first.
sorted_basis[i] = vector_to_matrix(basis)(best_total_perm(basis.size()-i-1));
}
return sorted_basis;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_SORT_BASIS_VECTORs_H__
dlib/svm/sort_basis_vectors_abstract.h
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_SORT_BASIS_VECTORs_ABSTRACT_H__
#ifdef DLIB_SORT_BASIS_VECTORs_ABSTRACT_H__
#include <vector>
#include "../matrix.h"
#include "../statistics.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename kernel_type,
typename vect1_type,
typename vect2_type,
typename vect3_type
>
const std::vector<typename kernel_type::sample_type> sort_basis_vectors (
const kernel_type& kern,
const vect1_type& samples,
const vect2_type& labels,
const vect3_type& basis,
double eps = 0.99
);
/*!
requires
- is_binary_classification_problem(samples, labels)
- 0 < eps <= 1
- kernel_type is a kernel function object as defined in dlib/svm/kernel_abstract.h
It must be capable of operating on the elements of samples and basis.
- vect1_type == a matrix or something convertible to a matrix via vector_to_matrix()
- vect2_type == a matrix or something convertible to a matrix via vector_to_matrix()
- vect3_type == a matrix or something convertible to a matrix via vector_to_matrix()
ensures
- A kernel based learning method ultimately needs to select a set of basis functions
represented by a particular choice of kernel function and a set of basis vectors.
This function attempts to order the elements of basis so that elements which are
most useful in solving the binary classification problem defined by samples and
labels come first.
- In particular, this function returns a std::vector, SB, of sorted basis vectors such that:
- 0 < SB.size() <= basis.size()
- SB will contain elements from basis but they will have been sorted so that
the most useful elements come first (i.e. SB[0] is the most important).
- eps notionally controls how big SB will be. Bigger eps corresponds to a
bigger basis. You can think of it as asking for the eps fraction of the
discriminating power contained in the input basis.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_SORT_BASIS_VECTORs_ABSTRACT_H__
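For reference, a minimal usage sketch of the new function (not part of this commit; the toy data, the kernel gamma of 0.1, and the way the candidate basis is chosen below are all hypothetical):
#include <iostream>
#include <vector>
#include <dlib/svm.h>
using namespace dlib;
int main()
{
    typedef matrix<double,2,1> sample_type;
    typedef radial_basis_kernel<sample_type> kernel_type;
    std::vector<sample_type> samples;
    std::vector<double> labels;
    // build a toy two-class problem
    for (double r = -10; r <= 10; r += 0.5)
    {
        sample_type samp;
        samp(0) = r;
        samp(1) = r*r;
        samples.push_back(samp);
        labels.push_back(r > 0 ? +1.0 : -1.0);
    }
    // pick some of the samples as the candidate basis set
    std::vector<sample_type> basis;
    for (unsigned long i = 0; i < samples.size(); i += 4)
        basis.push_back(samples[i]);
    // Sort the basis so the most discriminating vectors come first, keeping
    // enough of them to account for 99% of the discriminating power.
    const std::vector<sample_type> sorted_basis =
        sort_basis_vectors(kernel_type(0.1), samples, labels, basis, 0.99);
    std::cout << "kept " << sorted_basis.size() << " of "
              << basis.size() << " basis vectors" << std::endl;
}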