Commit f02e5477 authored by Davis King's avatar Davis King

Added find_max_parse_cky() and its supporting tools.

parent 626cfb90
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "optimization/max_sum_submatrix.h" #include "optimization/max_sum_submatrix.h"
#include "optimization/find_max_factor_graph_nmplp.h" #include "optimization/find_max_factor_graph_nmplp.h"
#include "optimization/find_max_factor_graph_viterbi.h" #include "optimization/find_max_factor_graph_viterbi.h"
#include "optimization/find_max_parse_cky.h"
#endif // DLIB_OPTIMIZATIOn_HEADER #endif // DLIB_OPTIMIZATIOn_HEADER
......
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_FIND_MAX_PaRSE_CKY_H__
#define DLIB_FIND_MAX_PaRSE_CKY_H__
#include "find_max_parse_cky_abstract.h"
#include <vector>
#include <string>
#include <sstream>
#include "../array2d.h"
namespace dlib
{
// -----------------------------------------------------------------------------------------
template <typename T>
struct constituent
{
unsigned long begin, end, k;
T left_tag;
T right_tag;
};
const unsigned long END_OF_TREE = 0xFFFFFFFF;
template <typename T>
struct parse_tree_element
{
constituent<T> c;
T tag; // id for the constituent corresponding to this level of the tree
unsigned long left;
unsigned long right;
double score;
};
namespace impl
{
template <typename T>
unsigned long fill_parse_tree(
std::vector<parse_tree_element<T> >& parse_tree,
const T& tag,
const array2d<std::map<T, parse_tree_element<T> > >& back,
long r, long c
)
/*!
requires
- back[r][c].size() == 0 || back[r][c].count(tag) != 0
!*/
{
// base case of the recursion
if (back[r][c].size() == 0)
{
return END_OF_TREE;
}
const unsigned long idx = parse_tree.size();
const parse_tree_element<T>& item = back[r][c].find(tag)->second;
parse_tree.push_back(item);
const long k = item.c.k;
const unsigned long idx_left = fill_parse_tree(parse_tree, item.c.left_tag, back, r, k-1);
const unsigned long idx_right = fill_parse_tree(parse_tree, item.c.right_tag, back, k, c);
parse_tree[idx].left = idx_left;
parse_tree[idx].right = idx_right;
return idx;
}
}
template <typename T, typename production_rule_function>
void find_max_parse_cky (
const std::vector<T>& sequence,
const production_rule_function& production_rules,
std::vector<std::vector<parse_tree_element<T> > >& parse_trees
)
{
parse_trees.clear();
array2d<std::map<T,double> > table(sequence.size(), sequence.size());
array2d<std::map<T,parse_tree_element<T> > > back(sequence.size(), sequence.size());
typedef typename std::map<T,double>::iterator itr;
typedef typename std::map<T,parse_tree_element<T> >::iterator itr_b;
for (long r = 0; r < table.nr(); ++r)
table[r][r][sequence[r]] = 0;
std::vector<std::pair<T,double> > possible_tags;
for (long r = table.nr()-2; r >= 0; --r)
{
for (long c = r+1; c < table.nc(); ++c)
{
for (long k = r; k < c; ++k)
{
for (itr i = table[k+1][c].begin(); i != table[k+1][c].end(); ++i)
{
for (itr j = table[r][k].begin(); j != table[r][k].end(); ++j)
{
constituent<T> con;
con.begin = r;
con.end = c+1;
con.k = k+1;
con.left_tag = j->first;
con.right_tag = i->first;
possible_tags.clear();
production_rules(sequence, con, possible_tags);
for (unsigned long m = 0; m < possible_tags.size(); ++m)
{
const double score = possible_tags[m].second + i->second + j->second;
itr match = table[r][c].find(possible_tags[m].first);
if (match == table[r][c].end() || score > match->second)
{
table[r][c][possible_tags[m].first] = score;
parse_tree_element<T> item;
item.c = con;
item.score = score;
item.tag = possible_tags[m].first;
back[r][c][possible_tags[m].first] = item;
}
}
}
}
}
}
}
// now use back pointers to build the parse trees
for (long r = 0; r < back.nr(); ++r)
{
for (long c = back.nc()-1; c > r; --c)
{
if (back[r][c].size() != 0)
{
// find the max scoring element in back[r][c]
itr_b max_i = back[r][c].begin();
itr_b i = max_i;
++i;
for (; i != back[r][c].end(); ++i)
{
if (i->second.score > max_i->second.score)
max_i = i;
}
parse_trees.resize(parse_trees.size()+1);
parse_trees.back().reserve(c);
impl::fill_parse_tree(parse_trees.back(), max_i->second.tag, back, r, c);
r = c;
break;
}
}
}
}
// -----------------------------------------------------------------------------------------
class parse_tree_to_string_error : public error
{
public:
parse_tree_to_string_error(const std::string& str): error(str) {}
};
namespace impl
{
template <bool enabled, typename T>
typename enable_if_c<enabled>::type conditional_print(
const T& item,
std::ostream& out
) { out << item << " "; }
template <bool enabled, typename T>
typename disable_if_c<enabled>::type conditional_print(
const T& ,
std::ostream&
) { }
template <bool print_tag, typename T, typename U >
void print_parse_tree_helper (
const std::vector<parse_tree_element<T> >& tree,
const std::vector<U>& items,
unsigned long i,
std::ostream& out
)
{
out << "[";
bool left_recurse = false;
// Only print if we are supposed to. Doing it this funny way avoids compiler
// errors in parse_tree_to_string() for the case where tag isn't
// printable.
conditional_print<print_tag>(tree[i].tag, out);
if (tree[i].left < tree.size())
{
left_recurse = true;
print_parse_tree_helper<print_tag>(tree, items, tree[i].left, out);
}
else
{
if (tree[i].c.begin < items.size())
{
out << items[tree[i].c.begin] << " ";
}
else
{
std::ostringstream sout;
sout << "Parse tree refers to element " << tree[i].c.begin
<< " of sequence which is only of size " << items.size() << ".";
throw parse_tree_to_string_error(sout.str());
}
}
if (tree[i].right < tree.size())
{
if (left_recurse == true)
out << " ";
print_parse_tree_helper<print_tag>(tree, items, tree[i].right, out);
}
else
{
if (tree[i].c.k < items.size())
{
out << items[tree[i].c.k];
}
else
{
std::ostringstream sout;
sout << "Parse tree refers to element " << tree[i].c.k
<< " of sequence which is only of size " << items.size() << ".";
throw parse_tree_to_string_error(sout.str());
}
}
out << "]";
}
}
// -----------------------------------------------------------------------------------------
template <typename T, typename U>
std::string parse_tree_to_string (
const std::vector<parse_tree_element<T> >& tree,
const std::vector<U>& items
)
{
if (tree.size() == 0)
return "";
std::ostringstream sout;
impl::print_parse_tree_helper<false>(tree, items, 0, sout);
return sout.str();
}
// -----------------------------------------------------------------------------------------
template <typename T, typename U>
std::string parse_tree_to_string_tagged (
const std::vector<parse_tree_element<T> >& tree,
const std::vector<U>& items
)
{
if (tree.size() == 0)
return "";
std::ostringstream sout;
impl::print_parse_tree_helper<true>(tree, items, 0, sout);
return sout.str();
}
// -----------------------------------------------------------------------------------------
}
#endif // DLIB_FIND_MAX_PaRSE_CKY_H__
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_FIND_MAX_PARsE_CKY_ABSTRACT_H__
#ifdef DLIB_FIND_MAX_PARsE_CKY_ABSTRACT_H__
#include <vector>
#include <string>
namespace dlib
{
// -----------------------------------------------------------------------------------------
template <typename T>
struct constituent
{
/*!
WHAT THIS OBJECT REPRESENTS
!*/
unsigned long begin, end, k;
T left_tag;
T right_tag;
};
const unsigned long END_OF_TREE = 0xFFFFFFFF;
template <typename T>
struct parse_tree_element
{
/*!
WHAT THIS OBJECT REPRESENTS
!*/
constituent<T> c;
T tag; // id for the constituent corresponding to this level of the tree
// subtrees. These are the index values into the std::vector that contains all the parse_tree_elements.
unsigned long left;
unsigned long right;
double score; // score for this tree
};
// -----------------------------------------------------------------------------------------
// -----------------------------------------------------------------------------------------
void example_production_rule_function (
const std::vector<T>& sequence,
const constituent<T>& c,
std::vector<std::pair<T,double> >& possible_tags
)
/*!
requires
- 0 <= c.begin < c.k < c.end <= sequence.size()
- possible_tags.size() == 0
ensures
- finds all the production rules that can turn c into a single non-terminal.
Puts the IDs of these rules and their scores into possible_tags.
- Note that example_production_rule_function() is not a real function. It is
here just to show you how to define production rule producing functions
for use with the find_max_parse_cky() routine defined below.
!*/
template <
typename T,
typename production_rule_function
>
void find_max_parse_cky (
const std::vector<T>& sequence,
const production_rule_function& production_rules,
std::vector<std::vector<parse_tree_element<T> > >& parse_trees
);
/*!
requires
- production_rule_function == a function or function object with the same
interface as example_production_rule_function defined above.
!*/
// -----------------------------------------------------------------------------------------
// -----------------------------------------------------------------------------------------
class parse_tree_to_string_error : public error
{
/*!
WHAT THIS OBJECT REPRESENTS
!*/
};
// -----------------------------------------------------------------------------------------
template <typename T, typename U>
std::string parse_tree_to_string (
const std::vector<parse_tree_element<T> >& tree,
const std::vector<U>& items
);
/*!
ensures
-
!*/
// -----------------------------------------------------------------------------------------
template <typename T, typename U>
std::string parse_tree_to_string_tagged (
const std::vector<parse_tree_element<T> >& tree,
const std::vector<U>& items
);
/*!
ensures
-
!*/
// -----------------------------------------------------------------------------------------
}
#endif // DLIB_FIND_MAX_PARsE_CKY_ABSTRACT_H__
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment