Commit 44c79bcb authored by Davis King's avatar Davis King

Relaxed the requirements on the feature extractor interface and also

added some tests to make sure the code really does work with the
relaxed interface.
parent 6d48b166
......@@ -137,7 +137,10 @@ namespace dlib
set_feature(55,1);
Therefore, the first argument to set_feature is the index of the feature
to be set while the second argument is the value the feature should take.
- This function only calls set_feature() once for each feature index.
Additionally, note that calling set_feature() multiple times with the same
feature index does NOT overwrite the old value, it adds to the previous
value. For example, if you call set_feature(55) 3 times then it will
result in feature 55 having a value of 3.
- This function only calls set_feature() with feature_index values < num_features()
!*/
......
......@@ -62,6 +62,47 @@ namespace
}
};
class feature_extractor_partial
{
public:
typedef unsigned long sample_type;
unsigned long num_features() const
{
return num_label_states*num_label_states + num_label_states*num_sample_states;
}
unsigned long order() const
{
return 1;
}
unsigned long num_labels() const
{
return num_label_states;
}
template <typename feature_setter, typename EXP>
void get_features (
feature_setter& set_feature,
const std::vector<sample_type>& x,
const matrix_exp<EXP>& y,
unsigned long position
) const
{
if (y.size() > 1)
{
set_feature(y(1)*num_label_states + y(0), 0.5);
set_feature(y(1)*num_label_states + y(0), 0.5);
}
set_feature(num_label_states*num_label_states +
y(0)*num_sample_states + x[position],0.4);
set_feature(num_label_states*num_label_states +
y(0)*num_sample_states + x[position],0.6);
}
};
bool called_rejct_labeling = false;
class feature_extractor2
{
......@@ -324,6 +365,53 @@ namespace
DLIB_TEST(std::abs(accuracy - 0.882) < 0.01);
}
// ----------------------------------------------------------------------------------------
void test2()
{
/*
The point of this test is to make sure calling set_feature() multiple
times works the way it is supposed to.
*/
print_spinner();
std::vector<std::vector<unsigned long> > samples;
std::vector<std::vector<unsigned long> > labels;
matrix<double> transition_probabilities(num_label_states, num_label_states);
transition_probabilities = 0.05, 0.90, 0.05,
0.05, 0.05, 0.90,
0.90, 0.05, 0.05;
matrix<double> emission_probabilities(num_label_states,num_sample_states);
emission_probabilities = 0.5, 0.5, 0.0,
0.0, 0.5, 0.5,
0.5, 0.0, 0.5;
make_dataset(transition_probabilities,emission_probabilities,
samples, labels, 1000);
dlog << LINFO << "samples.size(): "<< samples.size();
structural_sequence_labeling_trainer<feature_extractor> trainer;
structural_sequence_labeling_trainer<feature_extractor_partial> trainer_part;
trainer.set_c(4);
trainer_part.set_c(4);
trainer.set_num_threads(4);
trainer_part.set_num_threads(4);
// Learn to do sequence labeling from the dataset
sequence_labeler<feature_extractor> labeler = trainer.train(samples, labels);
sequence_labeler<feature_extractor_partial> labeler_part = trainer_part.train(samples, labels);
// Both feature extractors should be equivalent.
DLIB_TEST(length(labeler.get_weights() - labeler_part.get_weights()) < 1e-10);
}
// ----------------------------------------------------------------------------------------
class sequence_labeler_tester : public tester
......@@ -342,6 +430,8 @@ namespace
DLIB_TEST(called_rejct_labeling == false);
do_test<feature_extractor2>();
DLIB_TEST(called_rejct_labeling == true);
test2();
}
} a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment