Commit 46c02483 authored by Davis King's avatar Davis King

Fixed some corner cases in find_max_factor_graph_viterbi() and also

added unit tests.
parent 108c46e5
......@@ -76,6 +76,36 @@ namespace dlib
<< "\n\t std::numeric_limits<unsigned long>::max(): " << std::numeric_limits<unsigned long>::max()
);
if (prob.number_of_nodes() == 0)
{
map_assignment.clear();
return;
}
if (order == 0)
{
map_assignment.resize(prob.number_of_nodes());
for (unsigned long i = 0; i < map_assignment.size(); ++i)
{
matrix<unsigned long,1,1> node_state;
unsigned long best_state = 0;
double best_val = 0;
for (unsigned long s = 0; s < num_states; ++s)
{
node_state(0) = s;
const double temp = prob.factor_value(i,node_state);
if (temp > best_val)
{
best_val = temp;
best_state = s;
}
}
map_assignment[i] = best_state;
}
return;
}
const unsigned long trellis_size = static_cast<unsigned long>(std::pow(num_states,order));
unsigned long init_ring_size = 1;
......
......@@ -39,6 +39,7 @@ set (tests
entropy_coder.cpp
entropy_encoder_model.cpp
find_max_factor_graph_nmplp.cpp
find_max_factor_graph_viterbi.cpp
geometry.cpp
graph.cpp
hash.cpp
......
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include <dlib/optimization.h>
#include <dlib/rand.h>
#include "tester.h"
namespace
{
using namespace test;
using namespace dlib;
using namespace std;
logger dlog("test.find_max_factor_graph_viterbi");
// ----------------------------------------------------------------------------------------
dlib::rand rnd;
// ----------------------------------------------------------------------------------------
template <
unsigned long O,
unsigned long NS,
unsigned long num_nodes
>
class map_problem
{
public:
const static unsigned long order = O;
const static unsigned long num_states = NS;
map_problem()
{
data = randm(number_of_nodes(),std::pow(num_states,order+1), rnd);
}
unsigned long number_of_nodes (
) const
{
return num_nodes;
}
template <
typename EXP
>
double factor_value (
unsigned long node_id,
const matrix_exp<EXP>& node_states
) const
{
if (node_states.size() == 1)
return data(node_id, node_states(0));
else if (node_states.size() == 2)
return data(node_id, node_states(0) + node_states(1)*num_states);
else if (node_states.size() == 3)
return data(node_id, (node_states(0) + node_states(1)*num_states)*num_states + node_states(2));
else
return data(node_id, ((node_states(0) + node_states(1)*num_states)*num_states + node_states(2))*num_states + node_states(3));
}
matrix<double> data;
};
// ----------------------------------------------------------------------------------------
template <
typename map_problem
>
void brute_force_find_max_factor_graph_viterbi (
const map_problem& prob,
std::vector<unsigned long>& map_assignment
)
{
using namespace dlib::impl;
const int order = map_problem::order;
const int num_states = map_problem::num_states;
map_assignment.resize(prob.number_of_nodes());
double best_score = -std::numeric_limits<double>::infinity();
matrix<unsigned long,1,0> node_states;
node_states.set_size(prob.number_of_nodes());
node_states = 0;
do
{
double score = 0;
for (unsigned long i = 0; i < prob.number_of_nodes(); ++i)
{
score += prob.factor_value(i, (colm(node_states,range(i,i-std::min<int>(order,i)))));
}
if (score > best_score)
{
for (unsigned long i = 0; i < map_assignment.size(); ++i)
map_assignment[i] = node_states(i);
best_score = score;
}
} while(advance_state(node_states,num_states));
}
// ----------------------------------------------------------------------------------------
template <
unsigned long order,
unsigned long num_states,
unsigned long num_nodes
>
void do_test()
{
dlog << LINFO << "order: "<< order
<< " num_states: " << num_states
<< " num_nodes: " << num_nodes;
for (int i = 0; i < 25; ++i)
{
print_spinner();
map_problem<order,num_states,num_nodes> prob;
std::vector<unsigned long> assign, assign2;
brute_force_find_max_factor_graph_viterbi(prob, assign);
find_max_factor_graph_viterbi(prob, assign2);
DLIB_TEST_MSG(vector_to_matrix(assign) == vector_to_matrix(assign2),
trans(vector_to_matrix(assign))
<< trans(vector_to_matrix(assign2))
);
}
}
// ----------------------------------------------------------------------------------------
class test_find_max_factor_graph_viterbi : public tester
{
public:
test_find_max_factor_graph_viterbi (
) :
tester ("test_find_max_factor_graph_viterbi",
"Runs tests on the find_max_factor_graph_viterbi routine.")
{}
void perform_test (
)
{
do_test<1,3,0>();
do_test<1,3,1>();
do_test<1,3,2>();
do_test<0,3,2>();
do_test<1,3,8>();
do_test<2,3,7>();
do_test<3,3,8>();
do_test<4,3,8>();
do_test<0,3,8>();
do_test<4,3,1>();
do_test<4,3,0>();
do_test<0,3,0>();
do_test<1,2,8>();
do_test<2,2,7>();
do_test<3,2,8>();
do_test<0,2,8>();
do_test<1,1,8>();
do_test<2,1,8>();
do_test<3,1,8>();
do_test<0,1,8>();
}
} a;
}
......@@ -55,6 +55,7 @@ SRC += entropy_coder.cpp
SRC += entropy_encoder_model.cpp
SRC += geometry.cpp
SRC += find_max_factor_graph_nmplp.cpp
SRC += find_max_factor_graph_viterbi.cpp
SRC += graph.cpp
SRC += hash.cpp
SRC += hash_map.cpp
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment