Commit 84c15955 authored by Davis King's avatar Davis King

Made it so the kkmeans train function can take any kind of

vector container.  Not just dlib::matrix.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402375
parent bd863f3f
...@@ -96,8 +96,64 @@ namespace dlib ...@@ -96,8 +96,64 @@ namespace dlib
return centers.size(); return centers.size();
} }
template <typename matrix_type, typename matrix_type2> template <typename T, typename U>
void train ( void train (
const T& samples,
const U& initial_centers,
long max_iter = 1000000
)
{
do_train(vector_to_matrix(samples),vector_to_matrix(initial_centers),max_iter);
}
unsigned long operator() (
const sample_type& sample
) const
{
unsigned long label = 0;
scalar_type best_score = (*centers[0])(sample);
// figure out which center the given sample is closest too
for (unsigned long i = 1; i < centers.size(); ++i)
{
scalar_type temp = (*centers[i])(sample);
if (temp < best_score)
{
label = i;
best_score = temp;
}
}
return label;
}
void swap (
kkmeans& item
)
{
centers.swap(item.centers);
kc.swap(item.kc);
assignments.swap(item.assignments);
}
friend void serialize(const kkmeans& item, std::ostream& out)
{
serialize(item.centers, out);
serialize(item.kc, out);
serialize(item.assignments, out);
}
friend void deserialize(kkmeans& item, std::istream& in)
{
deserialize(item.centers, in);
deserialize(item.kc, in);
deserialize(item.assignments, in);
}
private:
template <typename matrix_type, typename matrix_type2>
void do_train (
const matrix_type& samples, const matrix_type& samples,
const matrix_type2& initial_centers, const matrix_type2& initial_centers,
long max_iter = 1000000 long max_iter = 1000000
...@@ -175,52 +231,6 @@ namespace dlib ...@@ -175,52 +231,6 @@ namespace dlib
} }
unsigned long operator() (
const sample_type& sample
) const
{
unsigned long label = 0;
scalar_type best_score = (*centers[0])(sample);
// figure out which center the given sample is closest too
for (unsigned long i = 1; i < centers.size(); ++i)
{
scalar_type temp = (*centers[i])(sample);
if (temp < best_score)
{
label = i;
best_score = temp;
}
}
return label;
}
void swap (
kkmeans& item
)
{
centers.swap(item.centers);
kc.swap(item.kc);
assignments.swap(item.assignments);
}
friend void serialize(const kkmeans& item, std::ostream& out)
{
serialize(item.centers, out);
serialize(item.kc, out);
serialize(item.assignments, out);
}
friend void deserialize(kkmeans& item, std::istream& in)
{
deserialize(item.centers, in);
deserialize(item.kc, in);
deserialize(item.assignments, in);
}
private:
typename array<scoped_ptr<kcentroid<kernel_type> > >::expand_1b_c centers; typename array<scoped_ptr<kcentroid<kernel_type> > >::expand_1b_c centers;
kcentroid<kernel_type> kc; kcentroid<kernel_type> kc;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment