Commit aad01bec authored by Davis King's avatar Davis King

Added a max iteration argument to the kkmeans train function.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402318
parent d1457cd6
......@@ -80,13 +80,15 @@ namespace dlib
return centers.size();
}
template <typename matrix_type>
template <typename matrix_type, typename matrix_type2>
void train (
const matrix_type& samples,
const matrix_type& initial_centers
const matrix_type2& initial_centers,
long max_iter = 1000000
)
{
COMPILE_TIME_ASSERT((is_same_type<typename matrix_type::type, sample_type>::value));
COMPILE_TIME_ASSERT((is_same_type<typename matrix_type2::type, sample_type>::value));
// make sure requires clause is not broken
DLIB_ASSERT(samples.nc() == 1 && initial_centers.nc() == 1 &&
......@@ -111,8 +113,10 @@ namespace dlib
bool assignment_changed = true;
// loop until the centers stabilize
while (assignment_changed)
long count = 0;
while (assignment_changed && count < max_iter)
{
++count;
assignment_changed = false;
// loop over all the samples and assign them to their closest centers
......
......@@ -90,21 +90,26 @@ namespace dlib
!*/
template <
typename matrix_type
typename matrix_type,
typename matrix_type2
>
void train (
const matrix_type& samples,
const matrix_type& initial_centers
const matrix_type2& initial_centers,
long max_iter = 1000000
);
/*!
requires
- matrix_type::type == sample_type (i.e. matrix_type should contain sample_type objects)
- matrix_type2::type == sample_type (i.e. matrix_type2 should contain sample_type objects)
- initial_centers.nc() == 1 (i.e. must be a column vector)
- samples.nc() == 1 (i.e. must be a column vector)
- initial_centers.nr() == number_of_centers()
ensures
- performs k-means clustering of the given set of samples. The initial center points
are taken from the initial_centers argument.
- loops over the data and continues to refine the clustering until either the cluster centers
don't move or we have done max_iter iterations over the data.
- After this function finishes you can call the operator() function below
to determine which centroid a given sample is closest to.
!*/
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment