Commit 0fc7ed31 authored by Davis King's avatar Davis King

Added nearest_center()

parent 4d20056a
......@@ -364,12 +364,12 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename vector_type,
typename array_type,
typename sample_type,
typename alloc
>
void find_clusters_using_kmeans (
const vector_type& samples,
const array_type& samples,
std::vector<sample_type, alloc>& centers,
unsigned long max_iter = 1000
)
......@@ -461,6 +461,40 @@ namespace dlib
}
// ----------------------------------------------------------------------------------------
template <
typename array_type,
typename EXP
>
unsigned long nearest_center (
const array_type& centers,
const matrix_exp<EXP>& sample
)
{
// make sure requires clause is not broken
DLIB_ASSERT(centers.size() > 0 && samples.size() > 0 && is_vector(sample),
"\t unsigned long nearest_center()"
<< "\n\t You have given invalid inputs to this function."
<< "\n\t centers.size(): " << centers.size()
<< "\n\t sample.size(): " << sample.size()
<< "\n\t is_vector(sample): " << is_vector(sample)
);
double best_dist = length_squared(centers[0] - sample);
unsigned long best_idx = 0;
for (unsigned long i = 1; i < centers.size(); ++i)
{
const double dist = length_squared(centers[i] - sample);
if (dist < best_dist)
{
best_dist = dist;
best_idx = i;
}
}
return best_idx;
}
// ----------------------------------------------------------------------------------------
}
......
......@@ -240,12 +240,12 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename vector_type,
typename array_type,
typename sample_type,
typename alloc
>
void find_clusters_using_kmeans (
const vector_type& samples,
const array_type& samples,
std::vector<sample_type, alloc>& centers,
unsigned long max_iter = 1000
);
......@@ -255,7 +255,7 @@ namespace dlib
- samples == a bunch of row or column vectors and they all must be of the
same length.
- centers.size() > 0
- vector_type == something with an interface compatible with std::vector
- array_type == something with an interface compatible with std::vector
and it must contain row or column vectors capable of being stored in
sample_type objects
- sample_type == a dlib::matrix capable of representing vectors
......@@ -267,6 +267,30 @@ namespace dlib
terminates.
!*/
// ----------------------------------------------------------------------------------------
template <
typename array_type,
typename EXP
>
unsigned long nearest_center (
const array_type& centers,
const matrix_exp<EXP>& sample
);
/*!
requires
- centers.size() > 0
- sample.size() > 0
- is_vector(sample) == true
- centers must be an array of vectors such that the following expression is
valid: length_squared(sample - centers[0]). (e.g. centers could be a
std::vector of matrix objects holding column vectors).
ensures
- returns the index that identifies the element of centers that is nearest to
sample. That is, returns a number IDX such that centers[IDX] is the element
of centers that minimizes length(centers[IDX]-sample).
!*/
// ----------------------------------------------------------------------------------------
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment