Commit 9e832ba3 authored by Davis King's avatar Davis King

- Added ordered_sample_pair

- Simplified some of the code a bit by making it use ordered_sample_pair
- Broke backwards compatibility, the syntax for passing order_by_distance
  and order_by_index to std::sort() is now slightly different since these
  functions are now templates.  This allows them to work on any kind of
  sample_pair or ordered_sample_pair object.
parent 0d26fec1
......@@ -10,6 +10,7 @@
#include "../rand.h"
#include <algorithm>
#include "sample_pair.h"
#include "ordered_sample_pair.h"
namespace dlib
{
......@@ -23,10 +24,11 @@ namespace dlib
vector_type& pairs
)
{
typedef typename vector_type::value_type T;
if (pairs.size() > 0)
{
// sort pairs so that we can avoid duplicates in the loop below
std::sort(pairs.begin(), pairs.end(), &order_by_index);
std::sort(pairs.begin(), pairs.end(), &order_by_index<T>);
// now put edges into temp while avoiding duplicates
vector_type temp;
......@@ -137,56 +139,13 @@ namespace dlib
remove_duplicate_edges(edges);
// now sort all the edges by distance and take the percent with the smallest distance
std::sort(edges.begin(), edges.end(), &order_by_distance);
std::sort(edges.begin(), edges.end(), &order_by_distance<sample_pair>);
const unsigned long out_size = std::min<unsigned long>((unsigned long)(num*percent), edges.size());
out.assign(edges.begin(), edges.begin() + out_size);
}
}
// ----------------------------------------------------------------------------------------
namespace impl2
{
struct helper
{
/*
This is like the sample_pair but lets the edges be directional
*/
helper(
unsigned long idx1,
unsigned long idx2,
double dist
) :
index1(idx1),
index2(idx2),
distance(dist)
{}
unsigned long index1;
unsigned long index2;
double distance;
};
inline bool order_by_index (
const helper& a,
const helper& b
)
{
return a.index1 < b.index1 || (a.index1 == b.index1 && a.index2 < b.index2);
}
inline bool total_order_by_distance (
const helper& a,
const helper& b
)
{
return a.distance < b.distance || (a.distance == b.distance && order_by_index(a,b));
}
}
// ----------------------------------------------------------------------------------------
template <
......@@ -223,7 +182,7 @@ namespace dlib
// we add each edge twice in the following loop. So multiply num by 2 to account for that.
num *= 2;
std::vector<impl2::helper> edges;
std::vector<ordered_sample_pair> edges;
edges.reserve(num);
std::vector<sample_pair, alloc> temp;
temp.reserve(num);
......@@ -241,45 +200,45 @@ namespace dlib
const double dist = dist_funct(samples[idx1], samples[idx2]);
if (dist < std::numeric_limits<double>::infinity())
{
edges.push_back(impl2::helper(idx1, idx2, dist));
edges.push_back(impl2::helper(idx2, idx1, dist));
edges.push_back(ordered_sample_pair(idx1, idx2, dist));
edges.push_back(ordered_sample_pair(idx2, idx1, dist));
}
}
}
std::sort(edges.begin(), edges.end(), &impl2::order_by_index);
std::sort(edges.begin(), edges.end(), &order_by_index<ordered_sample_pair>);
std::vector<impl2::helper>::iterator beg, itr;
std::vector<ordered_sample_pair>::iterator beg, itr;
// now copy edges into temp when they aren't duplicates and also only move in the k shortest for
// each index.
itr = edges.begin();
while (itr != edges.end())
{
// first find the bounding range for all the edges connected to node itr->index1
// first find the bounding range for all the edges connected to node itr->index1()
beg = itr;
while (itr != edges.end() && itr->index1 == beg->index1)
while (itr != edges.end() && itr->index1() == beg->index1())
++itr;
// If the node has more than k edges then sort them by distance so that
// we will end up with the k best.
if (static_cast<unsigned long>(itr - beg) > k)
{
std::sort(beg, itr, &impl2::total_order_by_distance);
std::sort(beg, itr, &order_by_distance_and_index<ordered_sample_pair>);
}
// take the k best unique edges from the range [beg,itr)
temp.push_back(sample_pair(beg->index1, beg->index2, beg->distance));
unsigned long prev_index2 = beg->index2;
temp.push_back(sample_pair(beg->index1(), beg->index2(), beg->distance()));
unsigned long prev_index2 = beg->index2();
++beg;
unsigned long count = 1;
for (; beg != itr && count < k; ++beg)
{
if (beg->index2 != prev_index2)
if (beg->index2() != prev_index2)
{
temp.push_back(sample_pair(beg->index1, beg->index2, beg->distance));
temp.push_back(sample_pair(beg->index1(), beg->index2(), beg->distance()));
++count;
}
prev_index2 = beg->index2;
prev_index2 = beg->index2();
}
}
......@@ -365,7 +324,7 @@ namespace dlib
}
// sort the edges so that duplicate edges will be adjacent
std::sort(edges.begin(), edges.end(), &order_by_index);
std::sort(edges.begin(), edges.end(), &order_by_index<sample_pair>);
// if the first edge is valid
if (edges[0].index1() < samples.size())
......@@ -399,8 +358,9 @@ namespace dlib
const vector_type& pairs
)
{
typedef typename vector_type::value_type T;
vector_type temp(pairs);
std::sort(temp.begin(), temp.end(), &order_by_index);
std::sort(temp.begin(), temp.end(), &order_by_index<T>);
for (unsigned long i = 1; i < temp.size(); ++i)
{
......@@ -417,7 +377,9 @@ namespace dlib
template <
typename vector_type
>
typename enable_if<is_same_type<sample_pair, typename vector_type::value_type>,unsigned long>::type
typename enable_if_c<(is_same_type<sample_pair, typename vector_type::value_type>::value ||
is_same_type<ordered_sample_pair, typename vector_type::value_type>::value),
unsigned long>::type
max_index_plus_one (
const vector_type& pairs
)
......@@ -506,7 +468,8 @@ namespace dlib
<< "\n\t percent: " << percent
);
std::sort(pairs.begin(), pairs.end(), &order_by_distance);
typedef typename vector_type::value_type T;
std::sort(pairs.begin(), pairs.end(), &order_by_distance<T>);
const unsigned long num = static_cast<unsigned long>((1.0-percent)*pairs.size());
......@@ -534,7 +497,8 @@ namespace dlib
<< "\n\t percent: " << percent
);
std::sort(pairs.rbegin(), pairs.rend(), &order_by_distance);
typedef typename vector_type::value_type T;
std::sort(pairs.rbegin(), pairs.rend(), &order_by_distance<T>);
const unsigned long num = static_cast<unsigned long>((1.0-percent)*pairs.size());
......
......@@ -132,8 +132,9 @@ namespace dlib
);
/*!
requires
- vector_type == a type with an interface compatible with std::vector and
it must in turn contain objects with an interface compatible with dlib::sample_pair
- vector_type == a type with an interface compatible with std::vector and it
must in turn contain objects with an interface compatible with
dlib::sample_pair or dlib::ordered_sample_pair.
ensures
- if (pairs contains any elements that are equal according to operator==) then
- returns true
......@@ -151,8 +152,9 @@ namespace dlib
);
/*!
requires
- vector_type == a type with an interface compatible with std::vector and
it must in turn contain objects with an interface compatible with dlib::sample_pair
- vector_type == a type with an interface compatible with std::vector and it
must in turn contain objects with an interface compatible with
dlib::sample_pair or dlib::ordered_sample_pair.
ensures
- if (pairs.size() == 0) then
- returns 0
......@@ -173,8 +175,9 @@ namespace dlib
);
/*!
requires
- vector_type == a type with an interface compatible with std::vector and
it must in turn contain objects with an interface compatible with dlib::sample_pair
- vector_type == a type with an interface compatible with std::vector and it
must in turn contain objects with an interface compatible with
dlib::sample_pair or dlib::ordered_sample_pair.
ensures
- Removes all elements of pairs that have a distance value greater than the
given threshold.
......@@ -192,8 +195,9 @@ namespace dlib
);
/*!
requires
- vector_type == a type with an interface compatible with std::vector and
it must in turn contain objects with an interface compatible with dlib::sample_pair
- vector_type == a type with an interface compatible with std::vector and it
must in turn contain objects with an interface compatible with
dlib::sample_pair or dlib::ordered_sample_pair.
ensures
- Removes all elements of pairs that have a distance value less than the
given threshold.
......@@ -212,8 +216,9 @@ namespace dlib
/*!
requires
- 0 <= percent < 1
- vector_type == a type with an interface compatible with std::vector and
it must in turn contain objects with an interface compatible with dlib::sample_pair
- vector_type == a type with an interface compatible with std::vector and it
must in turn contain objects with an interface compatible with
dlib::sample_pair or dlib::ordered_sample_pair.
ensures
- Removes the given upper percentage of the longest edges in pairs. I.e.
this function removes the long edges from pairs.
......@@ -232,8 +237,9 @@ namespace dlib
/*!
requires
- 0 <= percent < 1
- vector_type == a type with an interface compatible with std::vector and
it must in turn contain objects with an interface compatible with dlib::sample_pair
- vector_type == a type with an interface compatible with std::vector and it
must in turn contain objects with an interface compatible with
dlib::sample_pair or dlib::ordered_sample_pair.
ensures
- Removes the given upper percentage of the shortest edges in pairs. I.e.
this function removes the short edges from pairs.
......@@ -250,8 +256,9 @@ namespace dlib
);
/*!
requires
- vector_type == a type with an interface compatible with std::vector and
it must in turn contain objects with an interface compatible with dlib::sample_pair
- vector_type == a type with an interface compatible with std::vector and it
must in turn contain objects with an interface compatible with
dlib::sample_pair or dlib::ordered_sample_pair.
ensures
- Removes any duplicate edges from pairs. That is, for all elements of pairs,
A and B, such that A == B, only one of A or B will be in pairs after this
......
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ORDERED_SAMPLE_PaIR_H__
#define DLIB_ORDERED_SAMPLE_PaIR_H__
#include "ordered_sample_pair_abstract.h"
#include <limits>
#include "../serialize.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
class ordered_sample_pair
{
public:
ordered_sample_pair(
) :
_index1(0),
_index2(0)
{
_distance = std::numeric_limits<double>::infinity();
}
ordered_sample_pair (
const unsigned long idx1,
const unsigned long idx2,
const double dist
)
{
_distance = dist;
_index1 = idx1;
_index2 = idx2;
}
const unsigned long& index1 (
) const { return _index1; }
const unsigned long& index2 (
) const { return _index2; }
const double& distance (
) const { return _distance; }
private:
unsigned long _index1;
unsigned long _index2;
double _distance;
};
// ----------------------------------------------------------------------------------------
inline bool operator == (
const ordered_sample_pair& a,
const ordered_sample_pair& b
)
{
return a.index1() == b.index1() && a.index2() == b.index2();
}
inline bool operator != (
const ordered_sample_pair& a,
const ordered_sample_pair& b
)
{
return !(a == b);
}
// ----------------------------------------------------------------------------------------
inline void serialize (
const ordered_sample_pair& item,
std::ostream& out
)
{
try
{
serialize(item.index1(),out);
serialize(item.index2(),out);
serialize(item.distance(),out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type ordered_sample_pair");
}
}
inline void deserialize (
ordered_sample_pair& item,
std::istream& in
)
{
try
{
unsigned long idx1, idx2;
double dist;
deserialize(idx1,in);
deserialize(idx2,in);
deserialize(dist,in);
item = ordered_sample_pair(idx1, idx2, dist);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while deserializing object of type ordered_sample_pair");
}
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_ORDERED_SAMPLE_PaIR_H__
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_ORDERED_SAMPLE_PaIR_ABSTRACT_H__
#ifdef DLIB_ORDERED_SAMPLE_PaIR_ABSTRACT_H__
#include <limits>
#include "../serialize.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
class ordered_sample_pair
{
/*!
WHAT THIS OBJECT REPRESENTS
This object is intended to represent an edge in a directed graph which has
data samples at its vertices. So it contains two integers (index1 and
index2) which represent the identifying indices of the samples at the ends
of an edge.
This object also contains a double which can be used for any purpose.
!*/
public:
ordered_sample_pair(
);
/*!
ensures
- #index1() == 0
- #index2() == 0
- #distance() == std::numeric_limits<double>::infinity()
!*/
ordered_sample_pair (
const unsigned long idx1,
const unsigned long idx2,
const double dist
);
/*!
ensures
- #index1() == idx1
- #index2() == idx2
- #distance() == dist
!*/
const unsigned long& index1 (
) const;
/*!
ensures
- returns the first index value stored in this object
!*/
const unsigned long& index2 (
) const;
/*!
ensures
- returns the second index value stored in this object
!*/
const double& distance (
) const;
/*!
ensures
- returns the floating point number stored in this object
!*/
};
// ----------------------------------------------------------------------------------------
inline bool operator == (
const ordered_sample_pair& a,
const ordered_sample_pair& b
);
/*!
ensures
- returns a.index1() == b.index1() && a.index2() == b.index2();
I.e. returns true if a and b both represent the same pair and false otherwise.
Note that the distance field is not involved in this comparison.
!*/
inline bool operator != (
const ordered_sample_pair& a,
const ordered_sample_pair& b
);
/*!
ensures
- returns !(a == b)
!*/
// ----------------------------------------------------------------------------------------
inline void serialize (
const ordered_sample_pair& item,
std::ostream& out
);
/*!
provides serialization support
!*/
inline void deserialize (
ordered_sample_pair& item,
std::istream& in
);
/*!
provides deserialization support
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_ORDERED_SAMPLE_PaIR_ABSTRACT_H__
......@@ -59,22 +59,33 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <typename T>
inline bool order_by_index (
const sample_pair& a,
const sample_pair& b
const T& a,
const T& b
)
{
return a.index1() < b.index1() || (a.index1() == b.index1() && a.index2() < b.index2());
}
template <typename T>
inline bool order_by_distance (
const sample_pair& a,
const sample_pair& b
const T& a,
const T& b
)
{
return a.distance() < b.distance();
}
template <typename T>
bool order_by_distance_and_index (
const T& a,
const T& b
)
{
return a.distance() < b.distance() || (a.distance() == b.distance() && order_by_index(a,b));
}
// ----------------------------------------------------------------------------------------
inline bool operator == (
......
......@@ -74,28 +74,49 @@ namespace dlib
// ----------------------------------------------------------------------------------------
inline bool order_by_index (
const sample_pair& a,
const sample_pair& b
template <typename T>
bool order_by_index (
const T& a,
const T& b
) { return a.index1() < b.index1() || (a.index1() == b.index1() && a.index2() < b.index2()); }
/*!
requires
- T is a type with an interface compatible with sample_pair.
ensures
- provides a total ordering of sample_pair objects that will cause pairs that are
equal to be adjacent when sorted. This function can be used with std::sort() to
first sort sequences of sample_pair objects and then find duplicate edges.
!*/
inline bool order_by_distance (
const sample_pair& a,
const sample_pair& b
template <typename T>
bool order_by_distance (
const T& a,
const T& b
) { return a.distance() < b.distance(); }
/*!
requires
- T is a type with an interface compatible with sample_pair.
ensures
- provides a total ordering of sample_pair objects that causes pairs with
smallest distance to be the first in a sorted list. This function can be
used with std::sort()
!*/
template <typename T>
bool order_by_distance_and_index (
const T& a,
const T& b
) { return a.distance() < b.distance() || (a.distance() == b.distance() && order_by_index(a,b)); }
/*!
requires
- T is a type with an interface compatible with sample_pair.
ensures
- provides a total ordering of sample_pair objects that causes pairs with
smallest distance to be the first in a sorted list but also orders samples
with equal distances according to order_by_index(). This function can be
used with std::sort()
!*/
// ----------------------------------------------------------------------------------------
inline bool operator == (
......
......@@ -195,7 +195,7 @@ namespace
find_k_nearest_neighbors(samples, squared_euclidean_distance(), 1, edges);
DLIB_TEST(edges.size() == 4);
std::sort(edges.begin(), edges.end(), &order_by_index);
std::sort(edges.begin(), edges.end(), &order_by_index<sample_pair>);
DLIB_TEST(edges[0] == sample_pair(0,1,0));
DLIB_TEST(edges[1] == sample_pair(0,2,0));
......@@ -208,7 +208,7 @@ namespace
find_k_nearest_neighbors(samples, squared_euclidean_distance(3.9, 4.1), 3, edges);
DLIB_TEST(edges.size() == 4);
std::sort(edges.begin(), edges.end(), &order_by_index);
std::sort(edges.begin(), edges.end(), &order_by_index<sample_pair>);
DLIB_TEST(edges[0] == sample_pair(1,2,0));
DLIB_TEST(edges[1] == sample_pair(1,3,0));
......@@ -235,7 +235,7 @@ namespace
find_approximate_k_nearest_neighbors(samples, squared_euclidean_distance(), 1, 10000, seed, edges);
DLIB_TEST(edges.size() == 4);
std::sort(edges.begin(), edges.end(), &order_by_index);
std::sort(edges.begin(), edges.end(), &order_by_index<sample_pair>);
DLIB_TEST(edges[0] == sample_pair(0,1,0));
DLIB_TEST(edges[1] == sample_pair(0,2,0));
......@@ -248,7 +248,7 @@ namespace
find_approximate_k_nearest_neighbors(samples, squared_euclidean_distance(3.9, 4.1), 3, 10000, seed, edges);
DLIB_TEST(edges.size() == 4);
std::sort(edges.begin(), edges.end(), &order_by_index);
std::sort(edges.begin(), edges.end(), &order_by_index<sample_pair>);
DLIB_TEST(edges[0] == sample_pair(1,2,0));
DLIB_TEST(edges[1] == sample_pair(1,3,0));
......@@ -274,7 +274,7 @@ namespace
find_k_nearest_neighbors(samples, squared_euclidean_distance(), 2, edges);
DLIB_TEST(edges.size() == 4);
std::sort(edges.begin(), edges.end(), &order_by_index);
std::sort(edges.begin(), edges.end(), &order_by_index<sample_pair>);
DLIB_TEST(edges[0] == sample_pair(0,1,0));
DLIB_TEST(edges[1] == sample_pair(0,2,0));
......@@ -302,7 +302,7 @@ namespace
find_approximate_k_nearest_neighbors(samples, squared_euclidean_distance(), 2, 10000, seed, edges);
DLIB_TEST(edges.size() == 4);
std::sort(edges.begin(), edges.end(), &order_by_index);
std::sort(edges.begin(), edges.end(), &order_by_index<sample_pair>);
DLIB_TEST(edges[0] == sample_pair(0,1,0));
DLIB_TEST(edges[1] == sample_pair(0,2,0));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment