Commit c23ef609 authored by Davis King

added more tests

parent f18acdf8
...@@ -62,8 +62,58 @@ namespace ...@@ -62,8 +62,58 @@ namespace
} }
}; };
// Set to true by feature_extractor2::reject_labeling() each time it is invoked,
// so the test can check whether the labeling code called that hook.
// NOTE(review): the identifier is misspelled ("rejct"); a rename must update
// every use in this file at the same time.
bool called_rejct_labeling = false;
// Feature extractor for the sequence labeling tests that also supplies a
// reject_labeling() hook.  The hook never rejects anything; it only records
// that it was called by setting the global called_rejct_labeling flag.
class feature_extractor2
{
public:
    typedef unsigned long sample_type;

    // Size of the feature space: one slot per pair of label states followed
    // by one slot per (label state, sample state) pair.
    unsigned long num_features() const
    {
        const unsigned long label_pair_slots   = num_label_states*num_label_states;
        const unsigned long label_sample_slots = num_label_states*num_sample_states;
        return label_pair_slots + label_sample_slots;
    }

    // Always reports 1.
    // NOTE(review): presumably the number of previous labels the model
    // conditions on -- confirm against the dlib feature extractor contract.
    unsigned long order() const
    {
        return 1;
    }

    // Number of distinct label values, taken from num_label_states.
    unsigned long num_labels() const
    {
        return num_label_states;
    }

    // Records the call in called_rejct_labeling and returns false
    // unconditionally.  None of the arguments are inspected.
    template <typename EXP>
    bool reject_labeling (
        const std::vector<sample_type>& x,
        const matrix_exp<EXP>& y,
        unsigned long position
    ) const
    {
        called_rejct_labeling = true;
        return false;
    }

    // Reports the active feature indices through set_feature:
    //   - when y holds more than one element, the index of the (y(1), y(0))
    //     pair within the first num_label_states*num_label_states slots;
    //   - always, the index of the (y(0), x[position]) pair, placed after
    //     those label-pair slots.
    // NOTE(review): presumably y(0) is the label at `position` and y(1) the
    // label preceding it -- confirm against the dlib documentation.
    template <typename feature_setter, typename EXP>
    void get_features (
        feature_setter& set_feature,
        const std::vector<sample_type>& x,
        const matrix_exp<EXP>& y,
        unsigned long position
    ) const
    {
        if (y.size() > 1)
        {
            set_feature(y(1)*num_label_states + y(0));
        }

        // The label/sample features start right after the label-pair block.
        const unsigned long label_sample_offset = num_label_states*num_label_states;
        set_feature(label_sample_offset + y(0)*num_sample_states + x[position]);
    }
};
void serialize(const feature_extractor&, std::ostream&) {} void serialize(const feature_extractor&, std::ostream&) {}
void deserialize(feature_extractor&, std::istream&) {} void deserialize(feature_extractor&, std::istream&) {}
// feature_extractor2 has no data members, so there is nothing to write out.
void serialize (
    const feature_extractor2&,
    std::ostream&
)
{
}

// Counterpart of serialize() above: nothing is read back from the stream.
void deserialize (
    feature_extractor2&,
    std::istream&
)
{
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -174,114 +224,124 @@ namespace ...@@ -174,114 +224,124 @@ namespace
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename fe_type>
class sequence_labeler_tester : public tester void do_test()
{ {
public: matrix<double> transition_probabilities(num_label_states, num_label_states);
sequence_labeler_tester ( transition_probabilities = 0.05, 0.90, 0.05,
) : 0.05, 0.05, 0.90,
tester ("test_sequence_labeler", 0.90, 0.05, 0.05;
"Runs tests on the sequence labeling code.")
{}
void perform_test ( matrix<double> emission_probabilities(num_label_states,num_sample_states);
) emission_probabilities = 0.5, 0.5, 0.0,
{ 0.0, 0.5, 0.5,
matrix<double> transition_probabilities(num_label_states, num_label_states); 0.5, 0.0, 0.5;
transition_probabilities = 0.05, 0.90, 0.05,
0.05, 0.05, 0.90, print_spinner();
0.90, 0.05, 0.05;
matrix<double> emission_probabilities(num_label_states,num_sample_states);
emission_probabilities = 0.5, 0.5, 0.0,
0.0, 0.5, 0.5,
0.5, 0.0, 0.5;
print_spinner(); std::vector<std::vector<unsigned long> > samples;
std::vector<std::vector<unsigned long> > labels;
make_dataset(transition_probabilities,emission_probabilities,
samples, labels, 1000);
dlog << LINFO << "samples.size(): "<< samples.size();
std::vector<std::vector<unsigned long> > samples; // print out some of the randomly sampled sequences
std::vector<std::vector<unsigned long> > labels; for (int i = 0; i < 10; ++i)
make_dataset(transition_probabilities,emission_probabilities, {
samples, labels, 1000); dlog << LINFO << "hidden states: " << trans(vector_to_matrix(labels[i]));
dlog << LINFO << "observed states: " << trans(vector_to_matrix(samples[i]));
dlog << LINFO << "******************************";
}
dlog << LINFO << "samples.size(): "<< samples.size(); print_spinner();
structural_sequence_labeling_trainer<fe_type> trainer;
trainer.set_c(4);
DLIB_TEST(trainer.get_c() == 4);
trainer.set_num_threads(4);
DLIB_TEST(trainer.get_num_threads() == 4);
// print out some of the randomly sampled sequences
for (int i = 0; i < 10; ++i)
{
dlog << LINFO << "hidden states: " << trans(vector_to_matrix(labels[i]));
dlog << LINFO << "observed states: " << trans(vector_to_matrix(samples[i]));
dlog << LINFO << "******************************";
}
print_spinner();
structural_sequence_labeling_trainer<feature_extractor> trainer;
trainer.set_c(4);
DLIB_TEST(trainer.get_c() == 4);
trainer.set_num_threads(4);
DLIB_TEST(trainer.get_num_threads() == 4);
// Learn to do sequence labeling from the dataset
sequence_labeler<fe_type> labeler = trainer.train(samples, labels);
std::vector<unsigned long> predicted_labels = labeler(samples[0]);
dlog << LINFO << "true hidden states: "<< trans(vector_to_matrix(labels[0]));
dlog << LINFO << "predicted hidden states: "<< trans(vector_to_matrix(predicted_labels));
// Learn to do sequence labeling from the dataset DLIB_TEST(vector_to_matrix(labels[0]) == vector_to_matrix(predicted_labels));
sequence_labeler<feature_extractor> labeler = trainer.train(samples, labels);
std::vector<unsigned long> predicted_labels = labeler(samples[0]);
dlog << LINFO << "true hidden states: "<< trans(vector_to_matrix(labels[0]));
dlog << LINFO << "predicted hidden states: "<< trans(vector_to_matrix(predicted_labels));
DLIB_TEST(vector_to_matrix(labels[0]) == vector_to_matrix(predicted_labels)); print_spinner();
print_spinner(); // We can also do cross-validation
matrix<double> confusion_matrix;
confusion_matrix = cross_validate_sequence_labeler(trainer, samples, labels, 4);
dlog << LINFO << "cross-validation: ";
dlog << LINFO << confusion_matrix;
double accuracy = sum(diag(confusion_matrix))/sum(confusion_matrix);
dlog << LINFO << "label accuracy: "<< accuracy;
DLIB_TEST(std::abs(accuracy - 0.882) < 0.01);
print_spinner();
// We can also do cross-validation
matrix<double> confusion_matrix;
confusion_matrix = cross_validate_sequence_labeler(trainer, samples, labels, 4);
dlog << LINFO << "cross-validation: ";
dlog << LINFO << confusion_matrix;
double accuracy = sum(diag(confusion_matrix))/sum(confusion_matrix);
dlog << LINFO << "label accuracy: "<< accuracy;
DLIB_TEST(std::abs(accuracy - 0.882) < 0.01);
print_spinner(); matrix<double,0,1> true_hmm_model_weights = log(join_cols(reshape_to_column_vector(transition_probabilities),
reshape_to_column_vector(emission_probabilities)));
sequence_labeler<fe_type> labeler_true(true_hmm_model_weights);
matrix<double,0,1> true_hmm_model_weights = log(join_cols(reshape_to_column_vector(transition_probabilities), confusion_matrix = test_sequence_labeler(labeler_true, samples, labels);
reshape_to_column_vector(emission_probabilities))); dlog << LINFO << "True HMM model: ";
dlog << LINFO << confusion_matrix;
accuracy = sum(diag(confusion_matrix))/sum(confusion_matrix);
dlog << LINFO << "label accuracy: "<< accuracy;
DLIB_TEST(std::abs(accuracy - 0.882) < 0.01);
sequence_labeler<feature_extractor> labeler_true(true_hmm_model_weights);
confusion_matrix = test_sequence_labeler(labeler_true, samples, labels);
dlog << LINFO << "True HMM model: ";
dlog << LINFO << confusion_matrix;
accuracy = sum(diag(confusion_matrix))/sum(confusion_matrix);
dlog << LINFO << "label accuracy: "<< accuracy;
DLIB_TEST(std::abs(accuracy - 0.882) < 0.01);
print_spinner();
print_spinner();
// Finally, the labeler can be serialized to disk just like most dlib objects.
ostringstream sout;
serialize(labeler, sout);
sequence_labeler<fe_type> labeler2;
// recall from disk
istringstream sin(sout.str());
deserialize(labeler2, sin);
confusion_matrix = test_sequence_labeler(labeler2, samples, labels);
dlog << LINFO << "deserialized labeler: ";
dlog << LINFO << confusion_matrix;
accuracy = sum(diag(confusion_matrix))/sum(confusion_matrix);
dlog << LINFO << "label accuracy: "<< accuracy;
DLIB_TEST(std::abs(accuracy - 0.882) < 0.01);
}
// Finally, the labeler can be serialized to disk just like most dlib objects. // ----------------------------------------------------------------------------------------
ostringstream sout;
serialize(labeler, sout); class sequence_labeler_tester : public tester
{
public:
sequence_labeler_tester (
) :
tester ("test_sequence_labeler",
"Runs tests on the sequence labeling code.")
{}
sequence_labeler<feature_extractor> labeler2; void perform_test (
// recall from disk )
istringstream sin(sout.str()); {
deserialize(labeler2, sin); do_test<feature_extractor>();
confusion_matrix = test_sequence_labeler(labeler2, samples, labels); DLIB_TEST(called_rejct_labeling == false);
dlog << LINFO << "deserialized labeler: "; do_test<feature_extractor2>();
dlog << LINFO << confusion_matrix; DLIB_TEST(called_rejct_labeling == true);
accuracy = sum(diag(confusion_matrix))/sum(confusion_matrix);
dlog << LINFO << "label accuracy: "<< accuracy;
DLIB_TEST(std::abs(accuracy - 0.882) < 0.01);
} }
} a; } a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment