Commit 39ca46ca authored by Davis King's avatar Davis King

Simplified find_max_parse_cky() by making it only output a single tree.

Also added find_trees_not_rooted_with_tag().
parent f1a08876
...@@ -131,10 +131,13 @@ namespace dlib ...@@ -131,10 +131,13 @@ namespace dlib
void find_max_parse_cky ( void find_max_parse_cky (
const std::vector<T>& sequence, const std::vector<T>& sequence,
const production_rule_function& production_rules, const production_rule_function& production_rules,
std::vector<std::vector<parse_tree_element<T> > >& parse_trees std::vector<parse_tree_element<T> >& parse_tree
) )
{ {
parse_trees.clear(); parse_tree.clear();
if (sequence.size() == 0)
return;
array2d<std::map<T,double> > table(sequence.size(), sequence.size()); array2d<std::map<T,double> > table(sequence.size(), sequence.size());
array2d<std::map<T,parse_tree_element<T> > > back(sequence.size(), sequence.size()); array2d<std::map<T,parse_tree_element<T> > > back(sequence.size(), sequence.size());
typedef typename std::map<T,double>::iterator itr; typedef typename std::map<T,double>::iterator itr;
...@@ -185,31 +188,23 @@ namespace dlib ...@@ -185,31 +188,23 @@ namespace dlib
// now use back pointers to build the parse trees // now use back pointers to build the parse trees
for (long r = 0; r < back.nr(); ++r) const long r = 0;
const long c = back.nc()-1;
if (back[r][c].size() != 0)
{ {
for (long c = back.nc()-1; c > r; --c)
{
if (back[r][c].size() != 0)
{
// find the max scoring element in back[r][c] // find the max scoring element in back[r][c]
itr_b max_i = back[r][c].begin(); itr_b max_i = back[r][c].begin();
itr_b i = max_i; itr_b i = max_i;
++i; ++i;
for (; i != back[r][c].end(); ++i) for (; i != back[r][c].end(); ++i)
{ {
if (i->second.score > max_i->second.score) if (i->second.score > max_i->second.score)
max_i = i; max_i = i;
}
parse_trees.resize(parse_trees.size()+1);
parse_trees.back().reserve(c);
impl::fill_parse_tree(parse_trees.back(), max_i->second.tag, back, r, c);
r = c;
break;
}
} }
parse_tree.reserve(c);
impl::fill_parse_tree(parse_tree, max_i->second.tag, back, r, c);
} }
} }
...@@ -303,14 +298,15 @@ namespace dlib ...@@ -303,14 +298,15 @@ namespace dlib
template <typename T, typename U> template <typename T, typename U>
std::string parse_tree_to_string ( std::string parse_tree_to_string (
const std::vector<parse_tree_element<T> >& tree, const std::vector<parse_tree_element<T> >& tree,
const std::vector<U>& words const std::vector<U>& words,
const unsigned long root_idx = 0
) )
{ {
if (tree.size() == 0) if (root_idx >= tree.size())
return ""; return "";
std::ostringstream sout; std::ostringstream sout;
impl::print_parse_tree_helper<false>(tree, words, 0, sout); impl::print_parse_tree_helper<false>(tree, words, root_idx, sout);
return sout.str(); return sout.str();
} }
...@@ -319,17 +315,56 @@ namespace dlib ...@@ -319,17 +315,56 @@ namespace dlib
template <typename T, typename U> template <typename T, typename U>
std::string parse_tree_to_string_tagged ( std::string parse_tree_to_string_tagged (
const std::vector<parse_tree_element<T> >& tree, const std::vector<parse_tree_element<T> >& tree,
const std::vector<U>& words const std::vector<U>& words,
const unsigned long root_idx = 0
) )
{ {
if (tree.size() == 0) if (root_idx >= tree.size())
return ""; return "";
std::ostringstream sout; std::ostringstream sout;
impl::print_parse_tree_helper<true>(tree, words, 0, sout); impl::print_parse_tree_helper<true>(tree, words, root_idx, sout);
return sout.str(); return sout.str();
} }
// -----------------------------------------------------------------------------------------
namespace impl
{
    /*!
        Recursive worker used by find_trees_not_rooted_with_tag().

        Starting at tree[idx], walks down through every node whose tag equals the
        given tag and appends to tree_roots the index of each first node reached
        whose tag differs.  Nothing below such a node is visited, so the collected
        sub-trees never overlap.  The left child is always explored before the
        right child, so indices are appended in left-to-right order.

        Only operator== is used to compare tags, which matches the requirement
        stated in the public specification of find_trees_not_rooted_with_tag().

        tree       - the parse tree, stored as a flat vector of nodes.
        tag        - the tag to be "sliced off" the top of the tree.
        tree_roots - output; indices are appended, the vector is never cleared here.
        idx        - index of the node to examine.
    !*/
    template <typename T>
    void helper_find_trees_without_tag (
        const std::vector<parse_tree_element<T> >& tree,
        const T& tag,
        std::vector<unsigned long>& tree_roots,
        unsigned long idx
    )
    {
        // Indices outside the tree are silently ignored.  This is what terminates
        // the recursion (presumably a missing child is encoded as an out-of-range
        // index -- confirm against parse_tree_element).
        if (idx >= tree.size())
            return;

        if (tree[idx].tag == tag)
        {
            // Still inside the region being sliced off: keep descending.
            helper_find_trees_without_tag(tree, tag, tree_roots, tree[idx].left);
            helper_find_trees_without_tag(tree, tag, tree_roots, tree[idx].right);
        }
        else
        {
            // First node on this path with a different tag: it roots a sub-tree.
            tree_roots.push_back(idx);
        }
    }
}
template <typename T>
void find_trees_not_rooted_with_tag (
    const std::vector<parse_tree_element<T> >& tree,
    const T& tag,
    std::vector<unsigned long>& tree_roots
)
{
    // The search always begins at the first element of the tree vector.
    const unsigned long root_idx = 0;

    // Throw away anything the caller left in the output before collecting results.
    tree_roots.clear();
    impl::helper_find_trees_without_tag(tree, tag, tree_roots, root_idx);
}
// ----------------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------------
} }
......
...@@ -175,31 +175,30 @@ namespace dlib ...@@ -175,31 +175,30 @@ namespace dlib
void find_max_parse_cky ( void find_max_parse_cky (
const std::vector<T>& words, const std::vector<T>& words,
const production_rule_function& production_rules, const production_rule_function& production_rules,
std::vector<std::vector<parse_tree_element<T> > >& parse_trees std::vector<parse_tree_element<T> >& parse_tree
); );
/*! /*!
requires requires
- production_rule_function == a function or function object with the same - production_rule_function == a function or function object with the same
interface as example_production_rule_function defined above. interface as example_production_rule_function defined above.
- It must be possible to store T objects in a std::map.
ensures ensures
- Uses the CKY algorithm to find the most probable/highest scoring parse tree - Uses the CKY algorithm to find the most probable/highest scoring binary parse
of the given vector of words. The output is stored in #parse_trees. tree of the given vector of words.
- This function outputs a set of non-overlapping parse trees. Each parse tree - if (#parse_tree.size() == 0) then
always spans the largest number of words possible, regardless of any other - There is no parse tree, using the given production_rules, that can cover
considerations (except that the parse trees cannot have overlapping word the given word sequence.
spans). For example, this function will never select a smaller parse tree, - else
even if it would have a better score, if it can possibly build a larger tree. - #parse_tree == the highest scoring parse tree that covers all the
Therefore, this function will only output multiple parse trees if it is elements of words.
impossible to form words into a single parse tree. - #parse_tree[0] == the root node of the parse tree.
- #parse_tree[0].score == the score of the parse tree. This is the sum of
the scores of all production rules used to construct the tree.
- #parse_tree[0].begin == 0
- #parse_tree[0].end == words.size()
- This function uses production_rules() to find out what the allowed production - This function uses production_rules() to find out what the allowed production
rules are. That is, production_rules() defines all properties of the grammar rules are. That is, production_rules() defines all properties of the grammar
used by find_max_parse_cky(). used by find_max_parse_cky().
- for all valid i:
- #parse_trees[i].size() != 0
- #parse_trees[i] == the root of the i'th parse tree.
- #parse_trees[i].score == the score of the i'th parse tree.
- The i'th parse tree spans all the elements of words in the range
[#parse_trees[i].c.begin, #parse_trees[i].c.end).
!*/ !*/
// ----------------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------------
...@@ -222,7 +221,8 @@ namespace dlib ...@@ -222,7 +221,8 @@ namespace dlib
> >
std::string parse_tree_to_string ( std::string parse_tree_to_string (
const std::vector<parse_tree_element<T> >& tree, const std::vector<parse_tree_element<T> >& tree,
const std::vector<U>& words const std::vector<U>& words,
const unsigned long root_idx = 0
); );
/*! /*!
requires requires
...@@ -240,6 +240,8 @@ namespace dlib ...@@ -240,6 +240,8 @@ namespace dlib
the dog ran the dog ran
Then the output would be the string "[[the dog] ran]" Then the output would be the string "[[the dog] ran]"
- Only the sub-tree rooted at tree[root_idx] will be output. If root_idx >=
tree.size() then the empty string is returned.
throws throws
- parse_tree_to_string_error - parse_tree_to_string_error
This exception is thrown if an invalid tree is detected. This might happen This exception is thrown if an invalid tree is detected. This might happen
...@@ -255,7 +257,8 @@ namespace dlib ...@@ -255,7 +257,8 @@ namespace dlib
> >
std::string parse_tree_to_string_tagged ( std::string parse_tree_to_string_tagged (
const std::vector<parse_tree_element<T> >& tree, const std::vector<parse_tree_element<T> >& tree,
const std::vector<U>& words const std::vector<U>& words,
const unsigned long root_idx = 0
); );
/*! /*!
requires requires
...@@ -277,6 +280,8 @@ namespace dlib ...@@ -277,6 +280,8 @@ namespace dlib
the dog ran the dog ran
Then the output would be the string "[S [NP the dog] ran]" Then the output would be the string "[S [NP the dog] ran]"
- Only the sub-tree rooted at tree[root_idx] will be output. If root_idx >=
tree.size() then the empty string is returned.
throws throws
- parse_tree_to_string_error - parse_tree_to_string_error
This exception is thrown if an invalid tree is detected. This might happen This exception is thrown if an invalid tree is detected. This might happen
...@@ -284,6 +289,40 @@ namespace dlib ...@@ -284,6 +289,40 @@ namespace dlib
shorter than it is supposed to be. shorter than it is supposed to be.
!*/ !*/
// -----------------------------------------------------------------------------------------
template <
typename T
>
void find_trees_not_rooted_with_tag (
const std::vector<parse_tree_element<T> >& tree,
const T& tag,
std::vector<unsigned long>& tree_roots
);
/*!
    requires
        - objects of type T must be comparable using operator==
    ensures
        - Finds all the largest non-overlapping trees in tree that are not rooted with
          the given tag.
        - find_trees_not_rooted_with_tag() is useful when you want to cut a parse tree
          into a bunch of sub-trees and you know that the top level of the tree is all
          composed of the same kind of tag.  So if you want to just "slice off" the top
          of the tree where this tag lives then this function is useful for doing that.
        - Any previous contents of tree_roots are discarded.
        - #tree_roots.size() == the number of sub-trees found.
        - for all valid i:
            - tree[#tree_roots[i]].tag != tag
        - The sub-trees are reported in left-to-right order.  That is, everything
          found beneath a node's left child is listed before anything found beneath
          its right child.
        - if (tree.size() == 0) then
            - #tree_roots.size() == 0
        - To make the operation of this function clearer, here are a few examples of
          what it will do:
            - if (tree[0].tag != tag) then
                - #tree_roots.size() == 1
                - #tree_roots[0] == 0
            - else if (tree[0].tag == tag but its immediate children's tags are not equal to tag) then
                - #tree_roots.size() == 2
                - #tree_roots[0] == tree[0].left
                - #tree_roots[1] == tree[0].right
!*/
// ----------------------------------------------------------------------------------------- // -----------------------------------------------------------------------------------------
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment