Commit 1b69ed2e authored by Davis King's avatar Davis King

Added another overload of find_max_factor_graph_potts() that works on

graphs that are regular grids.
parent 6bf0920d
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "general_potts_problem.h" #include "general_potts_problem.h"
#include "../algs.h" #include "../algs.h"
#include "../graph_utils.h" #include "../graph_utils.h"
#include "../array2d.h"
namespace dlib namespace dlib
{ {
...@@ -410,6 +411,139 @@ namespace dlib ...@@ -410,6 +411,139 @@ namespace dlib
} }
}; };
// ----------------------------------------------------------------------------------------
template <
typename label_image_type,
typename image_potts_model
>
class potts_grid_problem
{
label_image_type& label_img;
long nc;
long num_nodes;
unsigned char* labels;
const image_potts_model& model;
public:
const static unsigned long max_number_of_neighbors = 4;
potts_grid_problem (
label_image_type& label_img_,
const image_potts_model& image_potts_model_
) :
label_img(label_img_),
model(image_potts_model_)
{
num_nodes = model.nr()*model.nc();
nc = model.nc();
labels = &label_img[0][0];
}
unsigned long number_of_nodes (
) const { return num_nodes; }
unsigned long number_of_neighbors (
unsigned long
) const
{
return 4;
}
unsigned long get_neighbor_idx (
long node_id1,
long node_id2
) const
{
long diff = node_id2-node_id1;
if (diff > nc)
diff -= (long)number_of_nodes();
else if (diff < -nc)
diff += (long)number_of_nodes();
if (diff == 1)
return 0;
else if (diff == -1)
return 1;
else if (diff == nc)
return 2;
else
return 3;
}
unsigned long get_neighbor (
long node_id,
long idx
) const
{
switch(idx)
{
case 0:
{
long temp = node_id+1;
if (temp < (long)number_of_nodes())
return temp;
else
return temp - (long)number_of_nodes();
}
case 1:
{
long temp = node_id-1;
if (node_id >= 1)
return temp;
else
return temp + (long)number_of_nodes();
}
case 2:
{
long temp = node_id+nc;
if (temp < (long)number_of_nodes())
return temp;
else
return temp - (long)number_of_nodes();
}
case 3:
{
long temp = node_id-nc;
if (node_id >= nc)
return temp;
else
return temp + (long)number_of_nodes();
}
}
return 0;
}
void set_label (
const unsigned long& idx,
node_label value
)
{
*(labels+idx) = value;
}
node_label get_label (
const unsigned long& idx
) const
{
return *(labels+idx);
}
typedef typename image_potts_model::value_type value_type;
value_type factor_value (unsigned long idx) const
{
return model.factor_value(idx);
}
value_type factor_value_disagreement (unsigned long idx1, unsigned long idx2) const
{
return model.factor_value_disagreement(idx1,idx2);
}
};
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -531,6 +665,29 @@ namespace dlib ...@@ -531,6 +665,29 @@ namespace dlib
return score; return score;
} }
// ----------------------------------------------------------------------------------------
template <
typename potts_grid_problem,
typename mem_manager
>
typename potts_grid_problem::value_type potts_model_score (
const potts_grid_problem& prob,
const array2d<node_label,mem_manager>& labels
)
{
DLIB_ASSERT(prob.nr() == labels.nr() && prob.nc() == labels.nc(),
"\t value_type potts_model_score(prob,labels)"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t prob.nr(): " << labels.nr()
<< "\n\t prob.nc(): " << labels.nc()
);
typedef array2d<node_label,mem_manager> image_type;
// This const_cast is ok because the model object won't actually modify labels
dlib::impl::potts_grid_problem<image_type,potts_grid_problem> model(const_cast<image_type&>(labels),prob);
return potts_model_score(model);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -640,6 +797,23 @@ namespace dlib ...@@ -640,6 +797,23 @@ namespace dlib
} }
// ----------------------------------------------------------------------------------------
template <
typename potts_grid_problem,
typename mem_manager
>
void find_max_factor_graph_potts (
const potts_grid_problem& prob,
array2d<node_label,mem_manager>& labels
)
{
typedef array2d<node_label,mem_manager> image_type;
labels.set_size(prob.nr(), prob.nc());
dlib::impl::potts_grid_problem<image_type,potts_grid_problem> model(labels,prob);
find_max_factor_graph_potts(model);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "../matrix.h" #include "../matrix.h"
#include "min_cut_abstract.h" #include "min_cut_abstract.h"
#include "../graph_utils.h" #include "../graph_utils.h"
#include "../array2d/array2d_kernel_abstract.h"
namespace dlib namespace dlib
{ {
...@@ -159,6 +160,83 @@ namespace dlib ...@@ -159,6 +160,83 @@ namespace dlib
}; };
// ----------------------------------------------------------------------------------------
struct potts_grid_problem
{
/*!
WHAT THIS OBJECT REPRESENTS
This object is a specialization of a potts_problem to the case where
the graph is a regular grid where each node is connected to its four
neighbors. An example of this is an image where each pixel is a node
and is connected to its four immediate neighboring pixels. Therefore,
this object defines the interface this special kind of MAP problem
must implement if it is to be solved by the find_max_factor_graph_potts(potts_grid_problem,array2d)
routine defined at the end of this file.
Note that all nodes always have four neighbors, even nodes on the edge
of the graph. This is because these border nodes are connected to
the border nodes on the other side of the graph. That is, the graph
"wraps" around at the borders.
!*/
// This typedef should be for a type like int or double. It
// must also be capable of representing signed values.
typedef an_integer_or_real_type value_type;
long nr(
) const;
/*!
ensures
- returns the number of rows in the grid
!*/
long nc(
) const;
/*!
ensures
- returns the number of columns in the grid
!*/
value_type factor_value (
unsigned long idx
) const;
/*!
requires
- idx < nr()*nc()
ensures
- The grid is represented in row-major-order format. Therefore, idx
identifies a node according to its position in the row-major-order
representation of the grid graph. Or in other words, idx corresponds
to the following row and column location:
- row == idx/nc()
- col == idx%nc()
- returns a value which indicates how "good" it is to assign the idx-th
node the label of true. The larger the value, the more desirable it is
to give it this label. Similarly, a negative value indicates that it is
better to give the node a label of false.
!*/
value_type factor_value_disagreement (
unsigned long idx1,
unsigned long idx2
) const;
/*!
requires
- idx1 < nr()*nc()
- idx2 < nr()*nc()
- idx1 != idx2
- the idx1-th node and idx2-th node are neighbors in the grid graph.
ensures
- returns a number >= 0. This is the penalty for giving node idx1 and idx2
different labels. Larger values indicate a larger penalty.
- this function is symmetric. That is, it is true that:
factor_value_disagreement(i,j) == factor_value_disagreement(j,i)
!*/
};
// ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -230,6 +308,44 @@ namespace dlib ...@@ -230,6 +308,44 @@ namespace dlib
- Then this function returns F - D - Then this function returns F - D
!*/ !*/
// ----------------------------------------------------------------------------------------
template <
typename potts_grid_problem,
typename mem_manager
>
typename potts_grid_problem::value_type potts_model_score (
const potts_grid_problem& prob,
const array2d<node_label,mem_manager>& labels
);
/*!
requires
- prob.nr() == labels.nr()
- prob.nc() == labels.nc()
- potts_grid_problem == an object with an interface compatible with the
potts_grid_problem object defined above.
- for all valid i and j:
- prob.factor_value_disagreement(i,j) >= 0
- prob.factor_value_disagreement(i,j) == prob.factor_value_disagreement(j,i)
ensures
- computes the model score for the given potts_grid_problem. We define this
precisely below:
- let L(i) == the boolean label of the i-th variable in prob. Or in other
words, L(i) == (labels[i/labels.nc()][i%labels.nc()] != 0).
- let F == the sum of values of prob.factor_value(i) for only i values
where L(i) == true.
- Let D == the sum of values of prob.factor_value_disagreement(i,j)
for only i and j values which meet the following conditions:
- i and j are neighbors in the graph defined by prob, that is,
it is valid to call prob.factor_value_disagreement(i,j).
- L(i) != L(j)
- i < j
(i.e. We want to make sure to only count the edge between i and j once)
- Then this function returns F - D
!*/
// ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -283,6 +399,33 @@ namespace dlib ...@@ -283,6 +399,33 @@ namespace dlib
- the factor_value_disagreement(i,j) is stored in edge(g,i,j). - the factor_value_disagreement(i,j) is stored in edge(g,i,j).
!*/ !*/
// ----------------------------------------------------------------------------------------
template <
typename potts_grid_problem,
typename mem_manager
>
void find_max_factor_graph_potts (
const potts_grid_problem& prob,
array2d<node_label,mem_manager>& labels
);
/*!
requires
- potts_grid_problem == an object with an interface compatible with the
potts_grid_problem object defined above.
- for all valid i and j:
- prob.factor_value_disagreement(i,j) >= 0
- prob.factor_value_disagreement(i,j) == prob.factor_value_disagreement(j,i)
ensures
- This routine solves a version of a potts problem where the graph is a
regular grid where each node is connected to its four immediate neighbors.
In particular, this means that this function finds the assignments
to all the labels in prob which maximizes potts_model_score(prob,#labels).
- The optimal labels are stored in #labels.
- #labels.nr() == prob.nr()
- #labels.nc() == prob.nc()
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment