Commit 14ae1d76 authored by Davis King's avatar Davis King

Made the sequence trainer use the reject_labeling() information from

the feature extractor.  Also added the necessary input validation
to make sure this feature doesn't get misused.
parent 1e1d7fc8
...@@ -143,11 +143,14 @@ namespace dlib ...@@ -143,11 +143,14 @@ namespace dlib
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT(is_sequence_labeling_problem(x,y) == true, DLIB_ASSERT(is_sequence_labeling_problem(x,y) == true &&
contains_invalid_labeling(get_feature_extractor(), x, y) == false,
"\t sequence_labeler structural_sequence_labeling_trainer::train(x,y)" "\t sequence_labeler structural_sequence_labeling_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function" << "\n\t invalid inputs were given to this function"
<< "\n\t x.size(): " << x.size() << "\n\t x.size(): " << x.size()
<< "\n\t is_sequence_labeling_problem(x,y): " << is_sequence_labeling_problem(x,y) << "\n\t is_sequence_labeling_problem(x,y): " << is_sequence_labeling_problem(x,y)
<< "\n\t contains_invalid_labeling(get_feature_extractor(),x,y): " << contains_invalid_labeling(get_feature_extractor(),x,y)
<< "\n\t this: " << this
); );
#ifdef ENABLE_ASSERTS #ifdef ENABLE_ASSERTS
......
...@@ -197,7 +197,8 @@ namespace dlib ...@@ -197,7 +197,8 @@ namespace dlib
) const; ) const;
/*! /*!
requires requires
- is_sequence_labeling_problem(x, y) - is_sequence_labeling_problem(x, y) == true
- contains_invalid_labeling(get_feature_extractor(), x, y) == false
- for all valid i and j: y[i][j] < num_labels() - for all valid i and j: y[i][j] < num_labels()
ensures ensures
- Uses the structural_svm_sequence_labeling_problem to train a - Uses the structural_svm_sequence_labeling_problem to train a
......
...@@ -85,11 +85,13 @@ namespace dlib ...@@ -85,11 +85,13 @@ namespace dlib
fe(fe_) fe(fe_)
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT(is_sequence_labeling_problem(samples,labels) == true, DLIB_ASSERT(is_sequence_labeling_problem(samples,labels) == true &&
contains_invalid_labeling(fe, samples, labels) == false,
"\t structural_svm_sequence_labeling_problem::structural_svm_sequence_labeling_problem()" "\t structural_svm_sequence_labeling_problem::structural_svm_sequence_labeling_problem()"
<< "\n\t invalid inputs were given to this function" << "\n\t invalid inputs were given to this function"
<< "\n\t samples.size(): " << samples.size() << "\n\t samples.size(): " << samples.size()
<< "\n\t is_sequence_labeling_problem(samples,labels): " << is_sequence_labeling_problem(samples,labels) << "\n\t is_sequence_labeling_problem(samples,labels): " << is_sequence_labeling_problem(samples,labels)
<< "\n\t contains_invalid_labeling(fe,samples,labels): " << contains_invalid_labeling(fe,samples,labels)
<< "\n\t this: " << this << "\n\t this: " << this
); );
...@@ -187,12 +189,8 @@ namespace dlib ...@@ -187,12 +189,8 @@ namespace dlib
const matrix_exp<EXP>& node_states const matrix_exp<EXP>& node_states
) const ) const
{ {
/* TODO, uncomment this and setup some input validation to catch when rejection
is used incorrectly.
if (dlib::impl::call_reject_labeling_if_exists(fe, sequence, node_states, node_id)) if (dlib::impl::call_reject_labeling_if_exists(fe, sequence, node_states, node_id))
return -std::numeric_limits<double>::infinity(); return -std::numeric_limits<double>::infinity();
*/
double loss = 0; double loss = 0;
if (node_states(0) != label[node_id]) if (node_states(0) != label[node_id])
......
...@@ -50,7 +50,8 @@ namespace dlib ...@@ -50,7 +50,8 @@ namespace dlib
); );
/*! /*!
requires requires
- is_sequence_labeling_problem(samples, labels) - is_sequence_labeling_problem(samples, labels) == true
- contains_invalid_labeling(fe, samples, labels) == false
- for all valid i and j: labels[i][j] < fe.num_labels() - for all valid i and j: labels[i][j] < fe.num_labels()
ensures ensures
- This object attempts to learn a mapping from the given samples to the - This object attempts to learn a mapping from the given samples to the
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment