Commit 68fb08b1 authored by Davis King

Added segment_number_line().

parent 93f3a09f
......@@ -130,6 +130,121 @@ namespace dlib
return relabel.size();
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// A closed interval [lower, upper] on the real number line.  The full contract for
// each member is spelled out in the matching _abstract specification.
struct snl_range
{
    // Interval end points.  A default constructed range is the single point [0,0].
    double lower = 0;
    double upper = 0;

    snl_range() = default;

    // Zero-width range containing only val.  This constructor is intentionally
    // implicit: segment_number_line() pushes raw doubles into a vector of snl_range.
    snl_range(double val) : lower(val), upper(val) {}

    // Range spanning [l, u].  Callers must supply l <= u.
    snl_range(double l, double u) : lower(l), upper(u)
    {
        DLIB_ASSERT(lower <= upper);
    }

    // Length of the interval.
    double width() const
    {
        return upper - lower;
    }

    // Orders ranges by their lower end point only.
    bool operator<(const snl_range& item) const
    {
        return lower < item.lower;
    }
};
// Returns the smallest range that contains both a and b.
inline snl_range merge(const snl_range& a, const snl_range& b)
{
    const double lowest = std::min(a.lower, b.lower);
    const double highest = std::max(a.upper, b.upper);
    return snl_range(lowest, highest);
}
// Returns the size of the gap separating a and b.  Assuming each range has
// lower <= upper, the result is positive when the ranges are disjoint, zero when
// they touch at a single point, and negative when they overlap.
inline double distance (const snl_range& a, const snl_range& b)
{
    const double rightmost_lower = std::max(a.lower, b.lower);
    const double leftmost_upper = std::min(a.upper, b.upper);
    return rightmost_lower - leftmost_upper;
}
// Writes item to out formatted as [lower,upper] and returns out so calls can be chained.
inline std::ostream& operator<< (std::ostream& out, const snl_range& item )
{
    return out << "[" << item.lower << "," << item.upper << "]";
}
// ----------------------------------------------------------------------------------------
inline std::vector<snl_range> segment_number_line (
const std::vector<double>& x,
const double max_range_width
)
{
DLIB_CASSERT(max_range_width >= 0);
// create initial ranges, one for each value in x. So initially, all the ranges have
// width of 0.
std::vector<snl_range> ranges;
for (auto v : x)
ranges.push_back(v);
std::sort(ranges.begin(), ranges.end());
std::vector<snl_range> greedy_final_ranges;
if (ranges.size() == 0)
return greedy_final_ranges;
// We will try two different clustering strategies. One that does a simple greedy left
// to right sweep and another that does a bottom up agglomerative clustering. This
// first loop runs the greedy left to right sweep. Then at the end of this routine we
// will return the results that produced the tightest clustering.
greedy_final_ranges.push_back(ranges[0]);
for (size_t i = 1; i < ranges.size(); ++i)
{
auto m = merge(greedy_final_ranges.back(), ranges[i]);
if (m.width() <= max_range_width)
greedy_final_ranges.back() = m;
else
greedy_final_ranges.push_back(ranges[i]);
}
// Here we do the bottom up clustering. So compute the edges connecting our ranges.
// We will simply say there are edges between ranges if and only if they are
// immediately adjacent on the number line.
std::vector<sample_pair> edges;
for (size_t i = 1; i < ranges.size(); ++i)
edges.push_back(sample_pair(i-1,i, distance(ranges[i-1],ranges[i])));
std::sort(edges.begin(), edges.end(), order_by_distance<sample_pair>);
disjoint_subsets sets;
sets.set_size(ranges.size());
// Now start merging nodes.
for (auto edge : edges)
{
// find the next best thing to merge.
unsigned long a = sets.find_set(edge.index1());
unsigned long b = sets.find_set(edge.index2());
// merge it if it doesn't result in an interval that's too big.
auto m = merge(ranges[a], ranges[b]);
if (m.width() <= max_range_width)
{
unsigned long news = sets.merge_sets(a,b);
ranges[news] = m;
}
}
// Now create a list of the final ranges. We will do this by keeping track of which
// range we already added to final_ranges.
std::vector<snl_range> final_ranges;
std::vector<bool> already_output(ranges.size(), false);
for (unsigned long i = 0; i < sets.size(); ++i)
{
auto s = sets.find_set(i);
if (!already_output[s])
{
final_ranges.push_back(ranges[s]);
already_output[s] = true;
}
}
// only use the greedy clusters if they found a clustering with fewer clusters.
// Otherwise, the bottom up clustering probably produced a more sensible clustering.
if (final_ranges.size() <= greedy_final_ranges.size())
return final_ranges;
else
return greedy_final_ranges;
}
// ----------------------------------------------------------------------------------------
}
......
......@@ -42,6 +42,90 @@ namespace dlib
(i.e. cluster IDs are assigned contiguously and start at 0)
!*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
struct snl_range
{
/*!
WHAT THIS OBJECT REPRESENTS
This object represents an interval on the real number line, stored as
its two end points lower and upper.  It is used to store the outputs
of the segment_number_line() routine defined below.
!*/
snl_range(
);
/*!
ensures
- #lower == 0
- #upper == 0
!*/
snl_range(
double val
);
/*!
ensures
- #lower == val
- #upper == val
- #width() == 0
!*/
snl_range(
double l,
double u
);
/*!
requires
- l <= u
ensures
- #lower == l
- #upper == u
!*/
// The end points of the interval.
double lower;
double upper;
double width(
) const { return upper-lower; }
/*!
ensures
- returns the width of this interval on the number line.  That is,
returns upper-lower.
!*/
bool operator<(const snl_range& item) const { return lower < item.lower; }
/*!
ensures
- returns true if and only if lower < item.lower.  Only the lower end
points are compared, so this provides a total ordering of snl_range
objects assuming they are non-overlapping.
!*/
};
std::ostream& operator<< (std::ostream& out, const snl_range& item );
/*!
ensures
- prints item to out in the form [lower,upper].
- returns out
!*/
// ----------------------------------------------------------------------------------------
std::vector<snl_range> segment_number_line (
const std::vector<double>& x,
const double max_range_width
);
/*!
requires
- max_range_width >= 0
ensures
- Finds a clustering of the values in x and returns the ranges that define the
clustering. This routine uses a combination of bottom up clustering and a
simple greedy scan to try and find the most compact set of ranges that
contain all the values in x.
- Every value in x will be contained inside one of the returned snl_range
objects.
- All returned snl_range objects will have a width() <= max_range_width and
will also be non-overlapping.
- If x is empty then an empty vector is returned.
!*/
// ----------------------------------------------------------------------------------------
}
......
......@@ -303,6 +303,43 @@ namespace
DLIB_TEST(labels[1] == 1);
}
// Checks segment_number_line() on randomly generated data: first on three clumps
// separated by small gaps, then on one uniformly filled interval split at several
// different maximum widths.
void test_segment_number_line()
{
    dlib::rand rnd;
    const int num_rounds = 5000;

    // Each round draws one value from each clump, in this order.  Every clump is
    // narrower than the max width of 1 used below.
    const double clumps[3][2] = { {-1.5, -1.01}, {-0.99, -0.01}, {0.01, 1} };
    std::vector<double> samples;
    for (int round = 0; round < num_rounds; ++round)
    {
        for (const auto& bounds : clumps)
            samples.push_back(rnd.get_double_in_range(bounds[0], bounds[1]));
    }

    // Expect exactly one non-degenerate range per clump, each staying inside the
    // interval its values were drawn from.
    auto r = segment_number_line(samples,1);
    std::sort(r.begin(), r.end());
    DLIB_TEST(r.size() == 3);
    DLIB_TEST(-1.5 <= r[0].lower && r[0].lower < r[0].upper && r[0].upper <= -1.01);
    DLIB_TEST(-0.99 <= r[1].lower && r[1].lower < r[1].upper && r[1].upper <= -0.01);
    DLIB_TEST(0.01 <= r[2].lower && r[2].lower < r[2].upper && r[2].upper <= 1);

    // Now fill the width 3 interval [-2,1] with the same total number of draws.
    samples.clear();
    for (int draw = 0; draw < 3*num_rounds; ++draw)
        samples.push_back(rnd.get_double_in_range(-2, 1));

    // The wider the allowed range, the fewer ranges should come back.
    r = segment_number_line(samples,1);
    DLIB_TEST(r.size() == 3);

    r = segment_number_line(samples,1.5);
    DLIB_TEST(r.size() == 2);

    r = segment_number_line(samples,10.5);
    DLIB_TEST(r.size() == 1);
    DLIB_TEST(-2 <= r[0].lower && r[0].lower < r[0].upper && r[0].upper <= 1);
}
class test_clustering : public tester
{
public:
......@@ -316,6 +353,7 @@ namespace
)
{
test_bottom_up_clustering();
test_segment_number_line();
dlib::rand rnd;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment