Commit 41a87e59 authored by Branko Kokanovic's avatar Branko Kokanovic Committed by Davis E. King

Exposing chinese_whispers directly, closes #1642 (#1644)

* Exposing chinese_whispers directly

* Addressing comments

* Changed description
* Added support for distance weights
* Added tests

* Improving test to check returned results
parent 84b72278
......@@ -206,6 +206,36 @@ py::list chinese_whispers_clustering(py::list descriptors, float threshold)
return clusters;
}
py::list chinese_whispers_raw(py::list edges)
{
py::list clusters;
size_t num_edges = py::len(edges);
std::vector<sample_pair> edges_pairs;
std::vector<unsigned long> labels;
for (size_t idx = 0; idx < num_edges; ++idx)
{
py::tuple t = edges[idx].cast<py::tuple>();
if ((len(t) != 2) && (len(t) != 3))
{
PyErr_SetString( PyExc_IndexError, "Input must be a list of tuples with 2 or 3 elements.");
throw py::error_already_set();
}
size_t i = t[0].cast<size_t>();
size_t j = t[1].cast<size_t>();
double distance = (len(t) == 3) ? t[2].cast<double>(): 1;
edges_pairs.push_back(sample_pair(i, j, distance));
}
chinese_whispers(edges_pairs, labels);
for (size_t i = 0; i < labels.size(); ++i)
{
clusters.append(labels[i]);
}
return clusters;
}
void save_face_chips (
numpy_image<rgb_pixel> img,
const std::vector<full_object_detection>& faces,
......@@ -296,5 +326,10 @@ void bind_face_recognition(py::module &m)
m.def("chinese_whispers_clustering", &chinese_whispers_clustering, py::arg("descriptors"), py::arg("threshold"),
"Takes a list of descriptors and returns a list that contains a label for each descriptor. Clustering is done using dlib::chinese_whispers."
);
m.def("chinese_whispers", &chinese_whispers_raw, py::arg("edges"),
"Given a graph with vertices represented as numbers indexed from 0, this algorithm takes a list of edges and returns back a list that contains a labels (found clusters) for each vertex. "
"Edges are tuples with either 2 elements (integers presenting indexes of connected vertices) or 3 elements, where additional one element is float which presents distance weight of the edge). "
"Offers direct access to dlib::chinese_whispers."
);
}
from random import Random
from dlib import chinese_whispers
from pytest import raises
def test_chinese_whispers():
assert len(chinese_whispers([])) == 0
assert len(chinese_whispers([(0, 0), (1, 1)])) == 2
# Test that values from edges are actually used and that correct values are returned
labels = chinese_whispers([(0, 0), (0, 1), (1, 1)])
assert len(labels) == 2
assert labels[0] == labels[1]
labels = chinese_whispers([(0, 0), (1, 1)])
assert len(labels) == 2
assert labels[0] != labels[1]
def test_chinese_whispers_with_distance():
assert len(chinese_whispers([(0, 0, 1)])) == 1
assert len(chinese_whispers([(0, 0, 1), (0, 1, 0.5), (1, 1, 1)])) == 2
# Test that values from edges and distances are actually used and that correct values are returned
labels = chinese_whispers([(0, 0, 1), (0, 1, 1), (1, 1, 1)])
assert len(labels) == 2
assert labels[0] == labels[1]
labels = chinese_whispers([(0, 0, 1), (0, 1, 0.0), (1, 1, 1)])
assert len(labels) == 2
assert labels[0] != labels[1]
# Non-trivial test
edges = []
r = Random(0)
for i in range(100):
edges.append((i, i, 1))
edges.append((i, r.randint(0, 99), r.random()))
assert len(chinese_whispers(edges)) == 100
def test_chinese_whispers_type_checks():
"""
Tests contract (expected errors) in case client provides wrong types
"""
with raises(TypeError):
chinese_whispers()
with raises(TypeError):
chinese_whispers('foo')
with raises(RuntimeError):
chinese_whispers(['foo'])
with raises(IndexError):
chinese_whispers([(0,)])
with raises(IndexError):
chinese_whispers([(0, 1, 2, 3)])
with raises(RuntimeError):
chinese_whispers([('foo', 'bar')])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment