Commit 577ef82d authored by Davis King's avatar Davis King

Made this function capable of accepting anything that can be converted to

a matrix via vector_to_matrix()

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403891
parent 0d1e82e7
...@@ -45,7 +45,6 @@ namespace dlib ...@@ -45,7 +45,6 @@ namespace dlib
// Build the inverse matrix. This is basically a pseudo-inverse. // Build the inverse matrix. This is basically a pseudo-inverse.
return make_symmetric(eig.get_pseudo_v()*diagm(reciprocal(vals))*trans(eig.get_pseudo_v())); return make_symmetric(eig.get_pseudo_v()*diagm(reciprocal(vals))*trans(eig.get_pseudo_v()));
} }
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -55,12 +54,12 @@ namespace dlib ...@@ -55,12 +54,12 @@ namespace dlib
typename vect2_type, typename vect2_type,
typename vect3_type typename vect3_type
> >
const std::vector<typename kernel_type::sample_type> sort_basis_vectors ( const std::vector<typename kernel_type::sample_type> sort_basis_vectors_impl (
const kernel_type& kern, const kernel_type& kern,
const vect1_type& samples, const vect1_type& samples,
const vect2_type& labels, const vect2_type& labels,
const vect3_type& basis, const vect3_type& basis,
double eps = 0.99 double eps
) )
{ {
DLIB_ASSERT(is_binary_classification_problem(samples, labels) && DLIB_ASSERT(is_binary_classification_problem(samples, labels) &&
...@@ -88,16 +87,16 @@ namespace dlib ...@@ -88,16 +87,16 @@ namespace dlib
// compute the covariance matrix and the means of the two classes. // compute the covariance matrix and the means of the two classes.
for (unsigned long i = 0; i < samples.size(); ++i) for (unsigned long i = 0; i < samples.size(); ++i)
{ {
temp = kernel_matrix(kern, basis, samples[i]); temp = kernel_matrix(kern, basis, samples(i));
cov.add(temp); cov.add(temp);
if (labels[i] > 0) if (labels(i) > 0)
c1_mean += temp; c1_mean += temp;
else else
c2_mean += temp; c2_mean += temp;
} }
c1_mean /= sum(vector_to_matrix(labels) > 0); c1_mean /= sum(labels > 0);
c2_mean /= sum(vector_to_matrix(labels) < 0); c2_mean /= sum(labels < 0);
delta = c1_mean - c2_mean; delta = c1_mean - c2_mean;
...@@ -185,12 +184,37 @@ namespace dlib ...@@ -185,12 +184,37 @@ namespace dlib
{ {
// Note that we load sorted_basis backwards so that the most important // Note that we load sorted_basis backwards so that the most important
// basis elements come first. // basis elements come first.
sorted_basis[i] = vector_to_matrix(basis)(best_total_perm(basis.size()-i-1)); sorted_basis[i] = basis(best_total_perm(basis.size()-i-1));
} }
return sorted_basis; return sorted_basis;
} }
}
// ----------------------------------------------------------------------------------------
template <
typename kernel_type,
typename vect1_type,
typename vect2_type,
typename vect3_type
>
const std::vector<typename kernel_type::sample_type> sort_basis_vectors (
const kernel_type& kern,
const vect1_type& samples,
const vect2_type& labels,
const vect3_type& basis,
double eps = 0.99
)
{
return bs_impl::sort_basis_vectors_impl(kern,
vector_to_matrix(samples),
vector_to_matrix(labels),
vector_to_matrix(basis),
eps);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment