Commit 000b7e70 authored by Davis King's avatar Davis King

Changed find_max_factor_graph_nmplp() to use a simple hash table instead

of std::map.  This is significantly faster.  I also added some missing asserts
to validate that the map problems supplied by the user are valid.
parent f1b5fc97
...@@ -7,11 +7,123 @@ ...@@ -7,11 +7,123 @@
#include <vector> #include <vector>
#include <map> #include <map>
#include "../matrix.h" #include "../matrix.h"
#include "../hash.h"
namespace dlib namespace dlib
{ {
// ----------------------------------------------------------------------------------------
namespace impl
{
class simple_hash_map
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            This is a small open-addressing hash table that maps a key made of
            two unsigned longs, (a,b), to an unsigned long value.  A key hashes
            to a starting bucket and is only ever stored in one of the scan_dist
            consecutive buckets beginning there.  If an insert finds that whole
            window occupied by other keys then the table doubles in size and
            all the entries are rehashed.

            A bucket whose key1 is std::numeric_limits<unsigned long>::max() is
            considered unused, which is why that value may not be used as the
            first element of a key.  Entries are never removed.
    !*/
public:

    simple_hash_map(
    ) :
        scan_dist(6)
    {
        data.resize(5000);
    }

    void insert (
        const unsigned long a,
        const unsigned long b,
        const unsigned long value
    )
    /*!
        requires
            - a != std::numeric_limits<unsigned long>::max()
        ensures
            - #(*this)(a,b) == value
            - if (a,b) was already in the table then its old value is replaced
              rather than a second entry being added.
    !*/
    {
        const unsigned long empty_bucket = std::numeric_limits<unsigned long>::max();
        const uint32 h = first_bucket(a,b);

        for (uint32 i = 0; i < scan_dist; ++i)
        {
            bucket& slot = data[h+i];

            // Entries are never removed, so every bucket in the window that comes
            // before a stored key is occupied.  Therefore, if (a,b) is already
            // present we will reach it before reaching any empty bucket.
            if (slot.key1 == empty_bucket)
            {
                slot.key1 = a;
                slot.key2 = b;
                slot.value = value;
                return;
            }

            // Overwrite an existing entry.  Without this, repeated inserts of the
            // same key would eventually fill the whole scan window with copies of
            // that key, and since they all rehash to the same window no amount of
            // table doubling below could ever make room.
            if (slot.key1 == a && slot.key2 == b)
            {
                slot.value = value;
                return;
            }
        }

        // if we get this far it means the hash table is filling up.  So double its size.
        std::vector<bucket> old_data;
        old_data.resize(data.size()*2);
        old_data.swap(data);
        // data is now the new, larger, empty table and old_data holds the previous contents.
        for (std::vector<bucket>::size_type i = 0; i < old_data.size(); ++i)
        {
            if (old_data[i].key1 != empty_bucket)
            {
                insert(old_data[i].key1, old_data[i].key2, old_data[i].value);
            }
        }

        insert(a,b,value);
    }

    unsigned long operator() (
        const unsigned long a,
        const unsigned long b
    ) const
    /*!
        requires
            - this->insert(a,b,some_value) has been called
        ensures
            - returns the value stored at key (a,b)
    !*/
    {
        DLIB_ASSERT(a != b, "An invalid map_problem was given to find_max_factor_graph_nmplp()."
            << "\nNode " << a << " is listed as being a neighbor with itself, which is illegal.");

        uint32 h = first_bucket(a,b);

        for (unsigned long i = 0; i < scan_dist; ++i)
        {
            if (data[h].key1 == a && data[h].key2 == b)
            {
                return data[h].value;
            }
            ++h;
        }

        // this should never happen (since this function requires (a,b) to be in the hash table)
        DLIB_ASSERT(false, "An invalid map_problem was given to find_max_factor_graph_nmplp()."
            << "\nThe nodes in the map_problem are inconsistent because node "<<a<<" is in the neighbor list"
            << "\nof node "<<b<< " but node "<<b<<" isn't in the neighbor list of node "<<a<<". The neighbor relationship"
            << "\nis supposed to be symmetric."
        );
        return 0;
    }

private:

    uint32 first_bucket (
        const unsigned long a,
        const unsigned long b
    ) const
    /*!
        ensures
            - returns the index of the first bucket in the scan window for key (a,b).
              The result is always < data.size()-scan_dist, so all scan_dist buckets
              of the window are valid indices into data.
    !*/
    {
        const unsigned long block[2] = {a,b};
        return murmur_hash3(&block[0], sizeof(block))%(data.size()-scan_dist);
    }

    struct bucket
    {
        // having max() in key1 indicates that the bucket isn't used.
        bucket() : key1(std::numeric_limits<unsigned long>::max()), key2(0), value(0) {}
        unsigned long key1;
        unsigned long key2;
        unsigned long value;
    };

    std::vector<bucket> data;
    const unsigned int scan_dist;
};
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -54,7 +166,7 @@ namespace dlib ...@@ -54,7 +166,7 @@ namespace dlib
std::vector<double> gamma_elements; std::vector<double> gamma_elements;
gamma_elements.reserve(prob.number_of_nodes()*prob.num_states(prob.begin())*3); gamma_elements.reserve(prob.number_of_nodes()*prob.num_states(prob.begin())*3);
std::map<std::pair<unsigned long, unsigned long>, unsigned long> gamma_idx; impl::simple_hash_map gamma_idx;
...@@ -67,7 +179,7 @@ namespace dlib ...@@ -67,7 +179,7 @@ namespace dlib
{ {
const unsigned long id_j = prob.node_id(j); const unsigned long id_j = prob.node_id(j);
gamma_idx[std::make_pair(id_i,id_j)] = gamma_elements.size(); gamma_idx.insert(id_i, id_j, gamma_elements.size());
const unsigned long num_states_xj = prob.num_states(j); const unsigned long num_states_xj = prob.num_states(j);
...@@ -127,8 +239,8 @@ namespace dlib ...@@ -127,8 +239,8 @@ namespace dlib
for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j) for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j)
{ {
const unsigned long id_j = prob.node_id(j); const unsigned long id_j = prob.node_id(j);
double* const gamma_ji = &gamma_elements[gamma_idx[std::make_pair(id_j,id_i)]]; double* const gamma_ji = &gamma_elements[gamma_idx(id_j,id_i)];
double* const gamma_ij = &gamma_elements[gamma_idx[std::make_pair(id_i,id_j)]]; double* const gamma_ij = &gamma_elements[gamma_idx(id_i,id_j)];
const unsigned long num_states_xj = prob.num_states(j); const unsigned long num_states_xj = prob.num_states(j);
...@@ -149,7 +261,7 @@ namespace dlib ...@@ -149,7 +261,7 @@ namespace dlib
const unsigned long id_k = prob.node_id(k); const unsigned long id_k = prob.node_id(k);
++num_neighbors; ++num_neighbors;
const double* const gamma_ki = &gamma_elements[gamma_idx[std::make_pair(id_k,id_i)]]; const double* const gamma_ki = &gamma_elements[gamma_idx(id_k,id_i)];
sum_temp += gamma_ki[xi]; sum_temp += gamma_ki[xi];
} }
...@@ -184,7 +296,7 @@ namespace dlib ...@@ -184,7 +296,7 @@ namespace dlib
for (unsigned long xi = 0; xi < b.size(); ++xi) for (unsigned long xi = 0; xi < b.size(); ++xi)
{ {
const double* const gamma_ki = &gamma_elements[gamma_idx[std::make_pair(id_k,id_i)]]; const double* const gamma_ki = &gamma_elements[gamma_idx(id_k,id_i)];
b[xi] += gamma_ki[xi]; b[xi] += gamma_ki[xi];
} }
} }
......
...@@ -24,6 +24,12 @@ namespace dlib ...@@ -24,6 +24,12 @@ namespace dlib
looking at here is simply the interface definition for a map problem. looking at here is simply the interface definition for a map problem.
You must implement your own version of this object for the problem You must implement your own version of this object for the problem
you wish to solve and then pass it to the find_max_factor_graph_nmplp() routine. you wish to solve and then pass it to the find_max_factor_graph_nmplp() routine.
Note also that a factor graph should not have any nodes which are
neighbors with themselves. Additionally, the graph is undirected. This
means that if A is a neighbor of B then B must be a neighbor of A for
the map problem to be valid.
!*/ !*/
public: public:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment