You need to sign in or sign up before continuing.
Commit ea02f4d4 authored by Davis King's avatar Davis King

Added unit tests for the new graph cuts tools.

parent 80e501d8
...@@ -45,6 +45,7 @@ set (tests ...@@ -45,6 +45,7 @@ set (tests
find_max_factor_graph_viterbi.cpp find_max_factor_graph_viterbi.cpp
geometry.cpp geometry.cpp
graph.cpp graph.cpp
graph_cuts.cpp
hash.cpp hash.cpp
hash_map.cpp hash_map.cpp
hash_set.cpp hash_set.cpp
......
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include <dlib/graph_cuts.h>
#include <dlib/graph_utils.h>
#include <dlib/directed_graph.h>
#include <dlib/rand.h>
#include "tester.h"
namespace
{
using namespace test;
using namespace dlib;
using namespace std;
logger dlog("test.graph_cuts");
// ----------------------------------------------------------------------------------------
class dense_potts_problem
{
public:
typedef double value_type;
private:
matrix<value_type,0,2> factors1;
matrix<value_type> factors2;
matrix<node_label,0,1> labels;
public:
dense_potts_problem (
unsigned long num_nodes,
dlib::rand& rnd
)
{
factors1 = -7*(randm(num_nodes, 2, rnd)-0.5);
factors2 = make_symmetric(randm(num_nodes, num_nodes, rnd) > 0.5);
labels.set_size(num_nodes);
labels = FREE_NODE;
}
unsigned long number_of_nodes (
) const { return factors1.nr(); }
unsigned long number_of_neighbors (
unsigned long // idx
) const { return number_of_nodes()-1; }
unsigned long get_neighbor_idx (
unsigned long node_id1,
unsigned long node_id2
) const
{
if (node_id2 < node_id1)
return node_id2;
else
return node_id2-1;
}
unsigned long get_neighbor (
unsigned long node_id,
unsigned long idx
) const
{
DLIB_TEST(node_id < number_of_nodes());
DLIB_TEST(idx < number_of_neighbors(node_id));
if (idx < node_id)
return idx;
else
return idx+1;
}
void set_label (
const unsigned long& idx,
node_label value
)
{
labels(idx) = value;
}
node_label get_label (
const unsigned long& idx
) const
{
return labels(idx);
}
value_type factor_value (unsigned long idx, bool value) const
{
DLIB_TEST(idx < number_of_nodes());
if (value)
return factors1(idx,0);
else
return factors1(idx,1);
}
value_type factor_value_disagreement (unsigned long idx1, unsigned long idx2) const
{
DLIB_TEST(idx1 != idx2);
DLIB_TEST(idx1 < number_of_nodes());
DLIB_TEST(idx2 < number_of_nodes());
DLIB_TEST(get_neighbor_idx(idx1,idx2) < number_of_neighbors(idx1));
DLIB_TEST(get_neighbor_idx(idx2,idx1) < number_of_neighbors(idx2));
return factors2(idx1, idx2);
}
};
// ----------------------------------------------------------------------------------------
class image_potts_problem
{
public:
typedef double value_type;
const static unsigned long max_number_of_neighbors = 4;
private:
matrix<value_type,0,2> factors1;
matrix<value_type> factors2;
matrix<node_label,0,1> labels;
long nr;
long nc;
rectangle rect, inner_rect;
mutable long count;
public:
image_potts_problem (
long nr_,
long nc_,
dlib::rand& rnd
) : nr(nr_), nc(nc_)
{
rect = rectangle(0,0,nc-1,nr-1);
inner_rect = shrink_rect(rect,1);
const unsigned long num_nodes = nr*nc;
factors1 = -7*(randm(num_nodes, 2, rnd));
factors2 = randm(num_nodes, 4, rnd) > 0.5;
//factors1 = 0;
//set_rowm(factors1, range(0, factors1.nr()/2)) = -1;
labels.set_size(num_nodes);
labels = FREE_NODE;
count = 0;
}
~image_potts_problem()
{
dlog << LTRACE << "interface calls: " << count;
dlog << LTRACE << "labels hash: "<< murmur_hash3_128bit(&labels(0), labels.size()*sizeof(labels(0)), 0).first;
}
unsigned long number_of_nodes (
) const { return factors1.nr(); }
unsigned long number_of_neighbors (
unsigned long idx
) const
{
++count;
const point& p = get_loc(idx);
if (inner_rect.contains(p))
return 4;
else if (p == rect.tl_corner() ||
p == rect.bl_corner() ||
p == rect.tr_corner() ||
p == rect.br_corner() )
return 2;
else
return 3;
}
unsigned long get_neighbor_idx (
long node_id1,
long node_id2
) const
{
++count;
const point& p = get_loc(node_id1);
long ret = 0;
if (rect.contains(p + point(1,0)))
{
if (node_id2-node_id1 == 1)
return ret;
++ret;
}
if (rect.contains(p - point(1,0)))
{
if (node_id2-node_id1 == -1)
return ret;
++ret;
}
if (rect.contains(p + point(0,1)))
{
if (node_id2-node_id1 == nc)
return ret;
++ret;
}
return ret;
}
unsigned long get_neighbor (
long node_id,
long idx
) const
{
++count;
const point& p = get_loc(node_id);
if (rect.contains(p + point(1,0)))
{
if (idx == 0)
return node_id+1;
--idx;
}
if (rect.contains(p - point(1,0)))
{
if (idx == 0)
return node_id-1;
--idx;
}
if (rect.contains(p + point(0,1)))
{
if (idx == 0)
return node_id+nc;
--idx;
}
return node_id-nc;
}
void set_label (
const unsigned long& idx,
node_label value
)
{
++count;
labels(idx) = value;
}
node_label get_label (
const unsigned long& idx
) const
{
++count;
return labels(idx);
}
value_type factor_value (unsigned long idx, bool value) const
{
++count;
DLIB_TEST(idx < (unsigned long)number_of_nodes());
if (value)
return factors1(idx,0);
else
return factors1(idx,1);
}
value_type factor_value_disagreement (unsigned long idx1, unsigned long idx2) const
{
++count;
DLIB_TEST(idx1 != idx2);
DLIB_TEST(idx1 < (unsigned long)number_of_nodes());
DLIB_TEST(idx2 < (unsigned long)number_of_nodes());
// make this function symmetric
if (idx1 > idx2)
swap(idx1,idx2);
DLIB_TEST(get_neighbor(idx1, get_neighbor_idx(idx1, idx2)) == idx2);
DLIB_TEST(get_neighbor(idx2, get_neighbor_idx(idx2, idx1)) == idx1);
// the neighbor relationship better be symmetric
DLIB_TEST(get_neighbor_idx(idx1,idx2) < number_of_neighbors(idx1));
DLIB_TEST_MSG(get_neighbor_idx(idx2,idx1) < number_of_neighbors(idx2),
"\n idx1: "<< idx1 <<
"\n idx2: "<< idx2 <<
"\n get_neighbor_idx(idx2,idx1): "<< get_neighbor_idx(idx2,idx1) <<
"\n number_of_neighbors(idx2): " << number_of_neighbors(idx2) <<
"\n nr: "<< nr <<
"\n nc: "<< nc
);
return factors2(idx1, get_neighbor_idx(idx1,idx2));
}
private:
point get_loc (
const unsigned long& idx
) const
{
return point(idx%nc, idx/nc);
}
};
// ----------------------------------------------------------------------------------------
template <typename potts_model>
void brute_force_potts_model (
potts_model& g
)
{
potts_model m(g);
const unsigned long num = (unsigned long)std::pow(2, m.number_of_nodes());
double best_score = -std::numeric_limits<double>::infinity();
for (unsigned long i = 0; i < num; ++i)
{
for (unsigned long j = 0; j < m.number_of_nodes(); ++j)
{
unsigned long T = (1)<<j;
T = (T&i);
if (T != 0)
m.set_label(j,SINK_CUT);
else
m.set_label(j,SOURCE_CUT);
}
double score = potts_model_score(m);
if (score > best_score)
{
best_score = score;
g = m;
}
}
}
// ----------------------------------------------------------------------------------------
template <typename potts_prob>
void impl_test_potts_model (
potts_prob& p
)
{
using namespace std;
double brute_force_score;
double graph_cut_score;
{
potts_prob temp(p);
brute_force_potts_model(temp);
for (unsigned long i = 0; i < temp.number_of_nodes(); ++i)
{
dlog << LTRACE << "node " << i << ": "<< (int)temp.get_label(i);
}
brute_force_score = potts_model_score(temp);
dlog << LTRACE << "brute force score: "<< brute_force_score;
}
dlog << LTRACE << "******************";
{
potts_prob temp(p);
find_max_factor_graph_potts(temp);
for (unsigned long i = 0; i < temp.number_of_nodes(); ++i)
{
dlog << LTRACE << "node " << i << ": "<< (int)temp.get_label(i);
}
graph_cut_score = potts_model_score(temp);
dlog << LTRACE << "graph cut score: "<< graph_cut_score;
}
DLIB_TEST_MSG(graph_cut_score == brute_force_score, std::abs(graph_cut_score - brute_force_score));
dlog << LTRACE << "##################";
dlog << LTRACE << "##################";
dlog << LTRACE << "##################";
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// BASIC MIN CUT STUFF
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <typename directed_graph>
void brute_force_min_cut (
directed_graph& g,
unsigned long source,
unsigned long sink
)
{
typedef typename directed_graph::edge_type edge_weight_type;
const unsigned long num = (unsigned long)std::pow(2, g.number_of_nodes());
std::vector<node_label> best_cut(g.number_of_nodes(),FREE_NODE);
edge_weight_type best_score = std::numeric_limits<edge_weight_type>::max();
for (unsigned long i = 0; i < num; ++i)
{
for (unsigned long j = 0; j < g.number_of_nodes(); ++j)
{
unsigned long T = (1)<<j;
T = (T&i);
if (T != 0)
g.node(j).data = SINK_CUT;
else
g.node(j).data = SOURCE_CUT;
}
// ignore cuts that don't label the source or sink node the way we want.
if (g.node(source).data != SOURCE_CUT ||
g.node(sink).data != SINK_CUT)
continue;
edge_weight_type score = graph_cut_score(g);
if (score < best_score)
{
best_score = score;
for (unsigned long j = 0; j < g.number_of_nodes(); ++j)
best_cut[j] = g.node(j).data;
}
}
for (unsigned long j = 0; j < g.number_of_nodes(); ++j)
g.node(j).data = best_cut[j];
}
// ----------------------------------------------------------------------------------------
template <typename directed_graph>
void print_graph(
const directed_graph& g
)
{
using namespace std;
dlog << LTRACE << "number of nodes: "<< g.number_of_nodes();
for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
{
for (unsigned long n = 0; n < g.node(i).number_of_children(); ++n)
dlog << LTRACE << i << " -(" << g.node(i).child_edge(n) << ")-> " << g.node(i).child(n).index();
}
}
template <typename directed_graph>
void copy_edge_weights (
directed_graph& dest,
const directed_graph& src
)
{
for (unsigned long i = 0; i < src.number_of_nodes(); ++i)
{
for (unsigned long n = 0; n < src.node(i).number_of_children(); ++n)
{
dest.node(i).child_edge(n) = src.node(i).child_edge(n);
}
}
}
// ----------------------------------------------------------------------------------------
template <typename graph_type>
void pick_random_source_and_sink (
dlib::rand& rnd,
const graph_type& g,
unsigned long& source,
unsigned long& sink
)
{
source = rnd.get_random_32bit_number()%g.number_of_nodes();
sink = rnd.get_random_32bit_number()%g.number_of_nodes();
while (sink == source)
sink = rnd.get_random_32bit_number()%g.number_of_nodes();
}
// ----------------------------------------------------------------------------------------
template <typename dgraph_type>
void make_random_graph(
dlib::rand& rnd,
dgraph_type& g,
unsigned long& source,
unsigned long& sink
)
{
typedef typename dgraph_type::edge_type edge_weight_type;
g.clear();
const unsigned int num_nodes = rnd.get_random_32bit_number()%7 + 2;
g.set_number_of_nodes(num_nodes);
const unsigned int num_edges = static_cast<unsigned int>(num_nodes*(num_nodes-1)/2*rnd.get_random_double() + 0.5);
// add the right number of randomly selected edges
unsigned int count = 0;
while (count < num_edges)
{
unsigned long parent = rnd.get_random_32bit_number()%g.number_of_nodes();
unsigned long child = rnd.get_random_32bit_number()%g.number_of_nodes();
if (parent != child && g.has_edge(parent, child) == false)
{
++count;
g.add_edge(parent, child);
edge(g, parent, child) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
// have to have edges both ways
swap(parent, child);
g.add_edge(parent, child);
edge(g, parent, child) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
}
}
pick_random_source_and_sink(rnd, g, source, sink);
}
// ----------------------------------------------------------------------------------------
template <typename dgraph_type>
void make_random_chain_graph(
dlib::rand& rnd,
dgraph_type& g,
unsigned long& source,
unsigned long& sink
)
{
typedef typename dgraph_type::edge_type edge_weight_type;
g.clear();
const unsigned int num_nodes = rnd.get_random_32bit_number()%7 + 2;
g.set_number_of_nodes(num_nodes);
for (unsigned long i = 1; i < g.number_of_nodes(); ++i)
{
g.add_edge(i,i-1);
g.add_edge(i-1,i);
edge(g, i, i-1) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
edge(g, i-1, i) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
}
pick_random_source_and_sink(rnd, g, source, sink);
}
// ----------------------------------------------------------------------------------------
template <typename dgraph_type>
void make_random_grid_graph(
dlib::rand& rnd,
dgraph_type& g,
unsigned long& source,
unsigned long& sink
)
/*!
ensures
- makes a grid graph like the kind used for potts models.
!*/
{
typedef typename dgraph_type::edge_type edge_weight_type;
g.clear();
const long nr = rnd.get_random_32bit_number()%2 + 2;
const long nc = rnd.get_random_32bit_number()%2 + 2;
g.set_number_of_nodes(nr*nc+2);
const rectangle rect(0,0,nc-1,nr-1);
for (long r = 0; r < nr; ++r)
{
for (long c = 0; c < nc; ++c)
{
const point p(c,r);
const unsigned long i = p.y()*nc + p.x();
const point n2(c-1,r);
if (rect.contains(n2))
{
const unsigned long j = n2.y()*nc + n2.x();
g.add_edge(i,j);
g.add_edge(j,i);
edge(g,i,j) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
edge(g,j,i) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
}
const point n4(c,r-1);
if (rect.contains(n4))
{
const unsigned long j = n4.y()*nc + n4.x();
g.add_edge(i,j);
g.add_edge(j,i);
edge(g,i,j) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
edge(g,j,i) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
}
}
}
// use the last two nodes as source and sink. Also connect them to all the other nodes.
source = g.number_of_nodes()-1;
sink = g.number_of_nodes()-2;
for (unsigned long i = 0; i < g.number_of_nodes()-2; ++i)
{
g.add_edge(i,source);
g.add_edge(source,i);
g.add_edge(i,sink);
g.add_edge(sink,i);
edge(g,i,source) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
edge(g,source,i) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
edge(g,i,sink) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
edge(g,sink,i) = static_cast<edge_weight_type>(rnd.get_random_double()*50);
}
}
// ----------------------------------------------------------------------------------------
template <typename min_cut, typename dgraph_type>
void run_test_on_graphs (
const min_cut& mc,
dgraph_type& g1,
dgraph_type& g2,
unsigned long source,
unsigned long sink
)
{
typedef typename dgraph_type::edge_type edge_weight_type;
using namespace std;
dlog << LTRACE << "number of nodes: "<< g1.number_of_nodes();
dlog << LTRACE << "is graph connected: "<< graph_is_connected(g1);
dlog << LTRACE << "has self loops: "<< graph_contains_length_one_cycle(g1);
dlog << LTRACE << "SOURCE_CUT: " << source;
dlog << LTRACE << "SINK_CUT: " << sink;
mc(g1, source, sink);
brute_force_min_cut(g2, source, sink);
print_graph(g1);
// copy the edge weights from g2 back to g1 so we can compute cut scores
copy_edge_weights(g1, g2);
DLIB_TEST(g1.number_of_nodes() == g2.number_of_nodes());
for (unsigned long i = 0; i < g1.number_of_nodes(); ++i)
{
dlog << LTRACE << "node " << i << ": " << (int)g1.node(i).data << ", " << (int)g2.node(i).data;
if (g1.node(i).data != g2.node(i).data)
{
edge_weight_type cut_score = graph_cut_score(g1);
edge_weight_type brute_force_score = graph_cut_score(g2);
dlog << LTRACE << "graph cut score: "<< cut_score;
dlog << LTRACE << "brute force score: "<< brute_force_score;
if (brute_force_score != cut_score)
print_graph(g1);
DLIB_TEST_MSG(brute_force_score == cut_score,std::abs(brute_force_score-cut_score));
}
}
}
// ----------------------------------------------------------------------------------------
template <typename min_cut, typename edge_weight_type>
void test_graph_cuts(dlib::rand& rnd)
{
typedef typename dlib::directed_graph<node_label, edge_weight_type>::kernel_1a_c dgraph_type;
// we will create two identical graphs.
dgraph_type g1, g2;
min_cut mc;
unsigned long source, sink;
dlib::rand rnd_copy(rnd);
make_random_graph(rnd,g1, source, sink);
make_random_graph(rnd_copy,g2, source, sink);
run_test_on_graphs(mc, g1, g2, source, sink);
rnd_copy = rnd;
make_random_grid_graph(rnd,g1, source, sink);
make_random_grid_graph(rnd_copy,g2, source, sink);
run_test_on_graphs(mc, g1, g2, source, sink);
rnd_copy = rnd;
make_random_chain_graph(rnd,g1, source, sink);
make_random_chain_graph(rnd_copy,g2, source, sink);
run_test_on_graphs(mc, g1, g2, source, sink);
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
class graph_cuts_tester : public tester
{
public:
graph_cuts_tester (
) :
tester ("test_graph_cuts",
"Runs tests on the graph cuts tools.")
{}
dlib::rand rnd;
void perform_test (
)
{
for (int i = 0; i < 1000; ++i)
{
print_spinner();
dlog << LTRACE << "test_grpah_cuts<short> iter: " << i;
test_graph_cuts<min_cut,short>(rnd);
print_spinner();
dlog << LTRACE << "test_grpah_cuts<double> iter: " << i;
test_graph_cuts<min_cut,double>(rnd);
}
for (int k = 0; k < 300; ++k)
{
dlog << LTRACE << "image_potts_problem iter " << k;
print_spinner();
image_potts_problem p(3,3, rnd);
impl_test_potts_model(p);
}
for (int k = 0; k < 300; ++k)
{
dlog << LTRACE << "dense_potts_problem iter " << k;
print_spinner();
dense_potts_problem p(6, rnd);
impl_test_potts_model(p);
}
}
} a;
}
...@@ -60,6 +60,7 @@ SRC += find_max_factor_graph_nmplp.cpp ...@@ -60,6 +60,7 @@ SRC += find_max_factor_graph_nmplp.cpp
SRC += find_max_factor_graph_viterbi.cpp SRC += find_max_factor_graph_viterbi.cpp
SRC += geometry.cpp SRC += geometry.cpp
SRC += graph.cpp SRC += graph.cpp
SRC += graph_cuts.cpp
SRC += hash.cpp SRC += hash.cpp
SRC += hash_map.cpp SRC += hash_map.cpp
SRC += hash_set.cpp SRC += hash_set.cpp
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment