Commit d933439a authored by Davis King's avatar Davis King

Simplified the sequence_labeler interface a bit.

parent 90c9d0be
...@@ -186,6 +186,21 @@ namespace dlib ...@@ -186,6 +186,21 @@ namespace dlib
weights = 0; weights = 0;
} }
explicit sequence_labeler(
const matrix<double,0,1>& weights_
) :
weights(weights_)
{
// make sure requires clause is not broken
DLIB_ASSERT(fe.num_features() == static_cast<unsigned long>(weights_.size()),
"\t sequence_labeler::sequence_labeler(weights_)"
<< "\n\t These sizes should match"
<< "\n\t fe.num_features(): " << fe.num_features()
<< "\n\t weights_.size(): " << weights_.size()
<< "\n\t this: " << this
);
}
sequence_labeler( sequence_labeler(
const feature_extractor& fe_, const feature_extractor& fe_,
const matrix<double,0,1>& weights_ const matrix<double,0,1>& weights_
...@@ -195,7 +210,7 @@ namespace dlib ...@@ -195,7 +210,7 @@ namespace dlib
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT(fe_.num_features() == static_cast<unsigned long>(weights_.size()), DLIB_ASSERT(fe_.num_features() == static_cast<unsigned long>(weights_.size()),
"\t sequence_labeler::sequence_labeler()" "\t sequence_labeler::sequence_labeler(fe_,weights_)"
<< "\n\t These sizes should match" << "\n\t These sizes should match"
<< "\n\t fe_.num_features(): " << fe_.num_features() << "\n\t fe_.num_features(): " << fe_.num_features()
<< "\n\t weights_.size(): " << weights_.size() << "\n\t weights_.size(): " << weights_.size()
......
...@@ -196,9 +196,22 @@ namespace dlib ...@@ -196,9 +196,22 @@ namespace dlib
); );
/*! /*!
ensures ensures
- #get_weights().size() == fe.num_features() - #get_feature_extractor() == feature_extractor()
(i.e. it will have its default value)
- #get_weights().size() == #get_feature_extractor().num_features()
- #get_weights() == 0 - #get_weights() == 0
- #get_feature_extractor() will have an initial value for its type !*/
explicit sequence_labeler(
const matrix<double,0,1>& weights
);
/*!
requires
- feature_extractor().num_features() == weights.size()
ensures
- #get_feature_extractor() == feature_extractor()
(i.e. it will have its default value)
- #get_weights() == weights
!*/ !*/
sequence_labeler( sequence_labeler(
......
...@@ -162,7 +162,7 @@ int main() ...@@ -162,7 +162,7 @@ int main()
matrix<double,0,1> true_hmm_model_weights = log(join_cols(reshape_to_column_vector(transition_probabilities), matrix<double,0,1> true_hmm_model_weights = log(join_cols(reshape_to_column_vector(transition_probabilities),
reshape_to_column_vector(emission_probabilities))); reshape_to_column_vector(emission_probabilities)));
sequence_labeler<feature_extractor> labeler_true(feature_extractor(), true_hmm_model_weights); sequence_labeler<feature_extractor> labeler_true(true_hmm_model_weights);
confusion_matrix = test_sequence_labeler(labeler_true, samples, labels); confusion_matrix = test_sequence_labeler(labeler_true, samples, labels);
cout << "\nTrue HMM model: " << endl; cout << "\nTrue HMM model: " << endl;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment