Commit c4d69295 authored by Davis King

Added a simple linear kmeans implementation.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403472
parent e216fe04
......@@ -316,7 +316,7 @@ namespace dlib
// make sure requires clause is not broken
DLIB_CASSERT(num_centers > 1 && 0 <= percentile && percentile < 1 && samples.size() > 1,
DLIB_ASSERT(num_centers > 1 && 0 <= percentile && percentile < 1 && samples.size() > 1,
"\tvoid pick_initial_centers()"
<< "\n\tYou passed invalid arguments to this function"
<< "\n\tnum_centers: " << num_centers
......@@ -361,6 +361,106 @@ namespace dlib
}
// ----------------------------------------------------------------------------------------
/*
    Runs plain (linear) k-means on samples, refining the given centers in place.

    Parameters
        samples  - container of row or column vectors; indexed with operator[] and
                   queried with size().  Must not be empty.
        centers  - on entry, the initial cluster centers (must not be empty); on
                   exit, the refined centers.  The number of centers is unchanged.
        max_iter - upper bound on the number of assignment/update passes.  With
                   max_iter == 0 the loop body never runs and centers is untouched.

    The loop stops as soon as a pass leaves every sample's assignment unchanged,
    or after max_iter passes, whichever comes first.

    Note: a center that receives no samples during a pass is left equal to the
    `zero` vector built below (the accumulators are reset to it and the division
    is skipped for empty clusters).
*/
template <
    typename vector_type,
    typename sample_type,
    typename alloc
    >
void find_clusters_using_kmeans (
    const vector_type& samples,
    std::vector<sample_type, alloc>& centers,
    unsigned long max_iter = 1000
)
{
    // make sure requires clause is not broken
    DLIB_ASSERT(samples.size() > 0 && centers.size() > 0,
        "\tvoid find_clusters_using_kmeans()"
        << "\n\tYou passed invalid arguments to this function"
        << "\n\t samples.size(): " << samples.size()
        << "\n\t centers.size(): " << centers.size()
        );

#ifdef ENABLE_ASSERTS
    {
        // Every sample must be a vector with the same dimensions as the first one.
        const long nr = samples[0].nr();
        const long nc = samples[0].nc();
        for (unsigned long i = 0; i < samples.size(); ++i)
        {
            DLIB_ASSERT(is_vector(samples[i]) && samples[i].nr() == nr && samples[i].nc() == nc,
                "\tvoid find_clusters_using_kmeans()"
                << "\n\t You passed invalid arguments to this function"
                << "\n\t is_vector(samples[i]): " << is_vector(samples[i])
                << "\n\t samples[i].nr(): " << samples[i].nr()
                << "\n\t nr: " << nr
                << "\n\t samples[i].nc(): " << samples[i].nc()
                << "\n\t nc: " << nc
                << "\n\t i: " << i
                );
        }
    }
#endif

    typedef typename sample_type::type scalar_type;

    // A vector shaped like the centers with every element set to 0.  It is used
    // to reset the running sums before each center update.
    sample_type zero(centers[0]);
    zero = 0;

    // Bookkeeping vectors hold unsigned long, not sample_type, so they must not
    // reuse `alloc` (an allocator whose value_type is sample_type).  The default
    // allocator is used instead.
    std::vector<unsigned long> center_element_count;

    // tells which center a sample belongs to.  Initialized to centers.size(),
    // which can never be a valid center index, so that the first pass always
    // registers every sample as "changed".  (samples.size() would collide with a
    // real index whenever there are more centers than samples.)
    std::vector<unsigned long> assignments(samples.size(), centers.size());

    unsigned long iter = 0;
    bool centers_changed = true;
    while (centers_changed && iter < max_iter)
    {
        ++iter;
        centers_changed = false;
        center_element_count.assign(centers.size(), 0);

        // loop over each sample and see which center it is closest to
        for (unsigned long i = 0; i < samples.size(); ++i)
        {
            // find the best center for sample[i]
            scalar_type best_dist = std::numeric_limits<scalar_type>::max();
            unsigned long best_center = 0;
            for (unsigned long j = 0; j < centers.size(); ++j)
            {
                scalar_type dist = length(centers[j] - samples[i]);
                // strict '<' means ties go to the lowest-indexed center
                if (dist < best_dist)
                {
                    best_dist = dist;
                    best_center = j;
                }
            }

            if (assignments[i] != best_center)
            {
                centers_changed = true;
                assignments[i] = best_center;
            }

            center_element_count[best_center] += 1;
        }

        // now update all the centers: sum the samples assigned to each center...
        centers.assign(centers.size(), zero);
        for (unsigned long i = 0; i < samples.size(); ++i)
        {
            centers[assignments[i]] += samples[i];
        }
        // ...and divide by the count to get the mean.  Empty clusters are skipped
        // to avoid dividing by zero, which leaves them equal to `zero`.
        for (unsigned long i = 0; i < centers.size(); ++i)
        {
            if (center_element_count[i] != 0)
                centers[i] /= center_element_count[i];
        }
    }
}
// ----------------------------------------------------------------------------------------
}
......
......@@ -235,6 +235,36 @@ namespace dlib
- #centers == a vector containing the candidate centers found
!*/
// ----------------------------------------------------------------------------------------
template <
typename vector_type,
typename sample_type,
typename alloc
>
void find_clusters_using_kmeans (
const vector_type& samples,
std::vector<sample_type, alloc>& centers,
unsigned long max_iter = 1000
);
/*!
    requires
        - samples.size() > 0
        - samples == a bunch of row or column vectors and they all must be of the
          same length.
        - centers.size() > 0
        - vector_type == something with an interface compatible with std::vector
          and it must contain row or column vectors capable of being stored in
          sample_type objects
        - sample_type == a dlib::matrix capable of representing vectors
    ensures
        - performs regular old linear kmeans clustering on the samples.  The clustering
          begins with the initial set of centers given as an argument to this function.
          When it finishes #centers will contain the resulting centers.
        - #centers.size() == centers.size()
        - no more than max_iter iterations will be performed before this function
          terminates.  The function also stops early, as soon as an iteration
          leaves the center assigned to every sample unchanged.
        - if (max_iter == 0) then
            - no iterations are performed and #centers == centers
        - NOTE(review): in the implementation shown with this declaration, a center
          that has no samples assigned to it during an iteration is not averaged
          and is left as an all-zero vector rather than keeping its previous
          value -- confirm this is the intended contract.
!*/
// ----------------------------------------------------------------------------------------
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment