Commit 14ae1d76 authored by Davis King's avatar Davis King

Made the sequence trainer use the reject_labeling() information from

the feature extractor.  Also added the necessary input validation
to make sure this feature doesn't get misused.
parent 1e1d7fc8
......@@ -143,11 +143,14 @@ namespace dlib
{
// make sure requires clause is not broken
DLIB_ASSERT(is_sequence_labeling_problem(x,y) == true,
DLIB_ASSERT(is_sequence_labeling_problem(x,y) == true &&
contains_invalid_labeling(get_feature_extractor(), x, y) == false,
"\t sequence_labeler structural_sequence_labeling_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.size(): " << x.size()
<< "\n\t is_sequence_labeling_problem(x,y): " << is_sequence_labeling_problem(x,y)
<< "\n\t contains_invalid_labeling(get_feature_extractor(),x,y): " << contains_invalid_labeling(get_feature_extractor(),x,y)
<< "\n\t this: " << this
);
#ifdef ENABLE_ASSERTS
......
......@@ -197,7 +197,8 @@ namespace dlib
) const;
/*!
requires
- is_sequence_labeling_problem(x, y)
- is_sequence_labeling_problem(x, y) == true
- contains_invalid_labeling(get_feature_extractor(), x, y) == false
- for all valid i and j: y[i][j] < num_labels()
ensures
- Uses the structural_svm_sequence_labeling_problem to train a
......
......@@ -85,11 +85,13 @@ namespace dlib
fe(fe_)
{
// make sure requires clause is not broken
DLIB_ASSERT(is_sequence_labeling_problem(samples,labels) == true,
DLIB_ASSERT(is_sequence_labeling_problem(samples,labels) == true &&
contains_invalid_labeling(fe, samples, labels) == false,
"\t structural_svm_sequence_labeling_problem::structural_svm_sequence_labeling_problem()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t samples.size(): " << samples.size()
<< "\n\t is_sequence_labeling_problem(samples,labels): " << is_sequence_labeling_problem(samples,labels)
<< "\n\t contains_invalid_labeling(fe,samples,labels): " << contains_invalid_labeling(fe,samples,labels)
<< "\n\t this: " << this
);
......@@ -187,12 +189,8 @@ namespace dlib
const matrix_exp<EXP>& node_states
) const
{
/* TODO, uncomment this and setup some input validation to catch when rejection
is used incorrectly.
if (dlib::impl::call_reject_labeling_if_exists(fe, sequence, node_states, node_id))
return -std::numeric_limits<double>::infinity();
*/
double loss = 0;
if (node_states(0) != label[node_id])
......
......@@ -50,7 +50,8 @@ namespace dlib
);
/*!
requires
- is_sequence_labeling_problem(samples, labels)
- is_sequence_labeling_problem(samples, labels) == true
- contains_invalid_labeling(fe, samples, labels) == false
- for all valid i and j: labels[i][j] < fe.num_labels()
ensures
- This object attempts to learn a mapping from the given samples to the
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment