Commit 9abedba4 authored by Davis King's avatar Davis King

Added unit tests for the grid version of find_max_factor_graph_potts().

parent 1d6fc006
......@@ -9,6 +9,8 @@
#include <dlib/directed_graph.h>
#include <dlib/graph.h>
#include <dlib/rand.h>
#include <dlib/hash.h>
#include <dlib/image_transforms.h>
#include "tester.h"
......@@ -821,6 +823,66 @@ namespace
}
// ----------------------------------------------------------------------------------------
class test_potts_grid_problem
{
public:
test_potts_grid_problem(int seed_) :seed(seed_){}
int seed;
long nr() const { return 3;}
long nc() const { return 3;}
typedef double value_type;
value_type factor_value(unsigned long idx) const
{
return ((double)murmur_hash3(&idx, sizeof(idx), seed) - std::numeric_limits<uint32>::max()/2.0)/1000.0;
}
value_type factor_value_disagreement(unsigned long idx1, unsigned long idx2) const
{
return std::abs(factor_value(idx1+idx2)/10.0);
}
};
// ----------------------------------------------------------------------------------------
template <typename prob_type>
void brute_force_potts_grid_problem(
const prob_type& prob,
array2d<unsigned char>& labels
)
{
const unsigned long num = (unsigned long)std::pow(2.0, (double)prob.nr()*prob.nc());
array2d<unsigned char> temp(prob.nr(), prob.nc());
unsigned char* data = &temp[0][0];
double best_score = -std::numeric_limits<double>::infinity();
for (unsigned long i = 0; i < num; ++i)
{
for (unsigned long j = 0; j < temp.size(); ++j)
{
unsigned long T = (1)<<j;
T = (T&i);
if (T != 0)
*(data + j) = SINK_CUT;
else
*(data + j) = SOURCE_CUT;
}
double score = potts_model_score(prob, temp);
if (score > best_score)
{
best_score = score;
assign_image(labels, temp);
}
}
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......@@ -840,6 +902,25 @@ namespace
void perform_test (
)
{
for (int i = 0; i < 500; ++i)
{
array2d<unsigned char> labels, brute_labels;
test_potts_grid_problem prob(i);
find_max_factor_graph_potts(prob, labels);
brute_force_potts_grid_problem(prob, brute_labels);
DLIB_TEST(labels.nr() == brute_labels.nr());
DLIB_TEST(labels.nc() == brute_labels.nc());
for (long r = 0; r < labels.nr(); ++r)
{
for (long c = 0; c < labels.nc(); ++c)
{
bool normal = (labels[r][c] != 0);
bool brute = (brute_labels[r][c] != 0);
DLIB_TEST(normal == brute);
}
}
}
for (int i = 0; i < 1000; ++i)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment