Commit d933439a authored by Davis King's avatar Davis King

Simplified the sequence_labeler interface a bit.

parent 90c9d0be
......@@ -186,6 +186,21 @@ namespace dlib
weights = 0;
}
explicit sequence_labeler(
const matrix<double,0,1>& weights_
) :
weights(weights_)
{
// make sure requires clause is not broken
DLIB_ASSERT(fe.num_features() == static_cast<unsigned long>(weights_.size()),
"\t sequence_labeler::sequence_labeler(weights_)"
<< "\n\t These sizes should match"
<< "\n\t fe.num_features(): " << fe.num_features()
<< "\n\t weights_.size(): " << weights_.size()
<< "\n\t this: " << this
);
}
sequence_labeler(
const feature_extractor& fe_,
const matrix<double,0,1>& weights_
......@@ -195,7 +210,7 @@ namespace dlib
{
// make sure requires clause is not broken
DLIB_ASSERT(fe_.num_features() == static_cast<unsigned long>(weights_.size()),
"\t sequence_labeler::sequence_labeler()"
"\t sequence_labeler::sequence_labeler(fe_,weights_)"
<< "\n\t These sizes should match"
<< "\n\t fe_.num_features(): " << fe_.num_features()
<< "\n\t weights_.size(): " << weights_.size()
......
......@@ -196,9 +196,22 @@ namespace dlib
);
/*!
ensures
- #get_weights().size() == fe.num_features()
- #get_feature_extractor() == feature_extractor()
(i.e. it will have its default value)
- #get_weights().size() == #get_feature_extractor().num_features()
- #get_weights() == 0
- #get_feature_extractor() will have an initial value for its type
!*/
explicit sequence_labeler(
const matrix<double,0,1>& weights
);
/*!
requires
- feature_extractor().num_features() == weights.size()
ensures
- #get_feature_extractor() == feature_extractor()
(i.e. it will have its default value)
- #get_weights() == weights
!*/
sequence_labeler(
......
......@@ -162,7 +162,7 @@ int main()
matrix<double,0,1> true_hmm_model_weights = log(join_cols(reshape_to_column_vector(transition_probabilities),
reshape_to_column_vector(emission_probabilities)));
sequence_labeler<feature_extractor> labeler_true(feature_extractor(), true_hmm_model_weights);
sequence_labeler<feature_extractor> labeler_true(true_hmm_model_weights);
confusion_matrix = test_sequence_labeler(labeler_true, samples, labels);
cout << "\nTrue HMM model: " << endl;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment