Commit 5d1aad03 authored by Davis King's avatar Davis King

Cleaned up the kkmeans class and made it actually use

the min_change parameter.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402510
parent 290494bb
...@@ -101,7 +101,7 @@ namespace dlib ...@@ -101,7 +101,7 @@ namespace dlib
void train ( void train (
const T& samples, const T& samples,
const U& initial_centers, const U& initial_centers,
long max_iter = 1000000 long max_iter = 1000
) )
{ {
do_train(vector_to_matrix(samples),vector_to_matrix(initial_centers),max_iter); do_train(vector_to_matrix(samples),vector_to_matrix(initial_centers),max_iter);
...@@ -132,6 +132,13 @@ namespace dlib ...@@ -132,6 +132,13 @@ namespace dlib
scalar_type min_change_ scalar_type min_change_
) )
{ {
// make sure requires clause is not broken
DLIB_ASSERT( 0 <= min_change_ < 1,
"\tvoid kkmeans::set_min_change()"
<< "\n\tInvalid arguments to this function"
<< "\n\tthis: " << this
<< "\n\tmin_change_: " << min_change_
);
min_change = min_change_; min_change = min_change_;
} }
...@@ -184,7 +191,7 @@ namespace dlib ...@@ -184,7 +191,7 @@ namespace dlib
void do_train ( void do_train (
const matrix_type& samples, const matrix_type& samples,
const matrix_type2& initial_centers, const matrix_type2& initial_centers,
long max_iter = 1000000 long max_iter = 1000
) )
{ {
COMPILE_TIME_ASSERT((is_same_type<typename matrix_type::type, sample_type>::value)); COMPILE_TIME_ASSERT((is_same_type<typename matrix_type::type, sample_type>::value));
...@@ -214,10 +221,13 @@ namespace dlib ...@@ -214,10 +221,13 @@ namespace dlib
// loop until the centers stabilize // loop until the centers stabilize
long count = 0; long count = 0;
while (assignment_changed && count < max_iter) const unsigned long min_num_change = static_cast<unsigned long>(min_change*samples.size());
unsigned long num_changed = min_num_change;
while (assignment_changed && count < max_iter && num_changed >= min_num_change)
{ {
++count; ++count;
assignment_changed = false; assignment_changed = false;
num_changed = 0;
// loop over all the samples and assign them to their closest centers // loop over all the samples and assign them to their closest centers
for (long i = 0; i < samples.size(); ++i) for (long i = 0; i < samples.size(); ++i)
...@@ -240,6 +250,7 @@ namespace dlib ...@@ -240,6 +250,7 @@ namespace dlib
{ {
assignments[i] = best_center; assignments[i] = best_center;
assignment_changed = true; assignment_changed = true;
++num_changed;
} }
} }
......
...@@ -111,7 +111,7 @@ namespace dlib ...@@ -111,7 +111,7 @@ namespace dlib
); );
/*! /*!
requires requires
- matrix_type and matrix_type2 must either be dlib::matrix objects or convertable to dlib::matrix - matrix_type and matrix_type2 must either be dlib::matrix objects or convertible to dlib::matrix
via vector_to_matrix() via vector_to_matrix()
- matrix_type::type == sample_type (i.e. matrix_type should contain sample_type objects) - matrix_type::type == sample_type (i.e. matrix_type should contain sample_type objects)
- matrix_type2::type == sample_type (i.e. matrix_type2 should contain sample_type objects) - matrix_type2::type == sample_type (i.e. matrix_type2 should contain sample_type objects)
...@@ -122,8 +122,8 @@ namespace dlib ...@@ -122,8 +122,8 @@ namespace dlib
- performs k-means clustering of the given set of samples. The initial center points - performs k-means clustering of the given set of samples. The initial center points
are taken from the initial_centers argument. are taken from the initial_centers argument.
- loops over the data and continues to refine the clustering until either less than - loops over the data and continues to refine the clustering until either less than
get_min_change() fraction of the cluster centers move or we have done max_iter iterations get_min_change() fraction of the data points change clusters or we have done max_iter
over the data. iterations over the data.
- After this function finishes you can call the operator() function below - After this function finishes you can call the operator() function below
to determine which centroid a given sample is closest to. to determine which centroid a given sample is closest to.
!*/ !*/
...@@ -153,8 +153,8 @@ namespace dlib ...@@ -153,8 +153,8 @@ namespace dlib
) const; ) const;
/*! /*!
ensures ensures
- returns the minimum fraction of centers that need to change - returns the minimum fraction of data points that need to change
in an iteration of kmeans for the algorithm to keep going. centers in an iteration of kmeans for the algorithm to keep going.
!*/ !*/
void swap ( void swap (
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment