Commit fdc3af3a authored by Davis King

Refactoring and spec improvement. Still some work left to do though.

parent 7dc67b80
...@@ -25,7 +25,7 @@ namespace dlib ...@@ -25,7 +25,7 @@ namespace dlib
) )
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_CASSERT( is_sequence_labeling_problem(samples, labels) == true, DLIB_ASSERT( is_sequence_labeling_problem(samples, labels) == true,
"\tmatrix test_sequence_labeler()" "\tmatrix test_sequence_labeler()"
<< "\n\t invalid inputs were given to this function" << "\n\t invalid inputs were given to this function"
<< "\n\t is_sequence_labeling_problem(samples, labels): " << "\n\t is_sequence_labeling_problem(samples, labels): "
...@@ -44,8 +44,8 @@ namespace dlib ...@@ -44,8 +44,8 @@ namespace dlib
const unsigned long truth = labels[i][j]; const unsigned long truth = labels[i][j];
if (truth >= res.nr()) if (truth >= res.nr())
{ {
// make res big enough for this unexpected label // ignore labels the labeler doesn't know about.
res = join_cols(res, zeros_matrix<double>(truth-res.nr()+1, res.nc())); continue;
} }
res(truth, pred[j]) += 1; res(truth, pred[j]) += 1;
...@@ -69,7 +69,7 @@ namespace dlib ...@@ -69,7 +69,7 @@ namespace dlib
) )
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_CASSERT(is_sequence_labeling_problem(samples,labels) == true && DLIB_ASSERT(is_sequence_labeling_problem(samples,labels) == true &&
1 < folds && folds <= static_cast<long>(samples.size()), 1 < folds && folds <= static_cast<long>(samples.size()),
"\tmatrix cross_validate_sequence_labeler()" "\tmatrix cross_validate_sequence_labeler()"
<< "\n\t invalid inputs were given to this function" << "\n\t invalid inputs were given to this function"
...@@ -78,6 +78,25 @@ namespace dlib ...@@ -78,6 +78,25 @@ namespace dlib
<< "\n\t is_sequence_labeling_problem(samples,labels): " << is_sequence_labeling_problem(samples,labels) << "\n\t is_sequence_labeling_problem(samples,labels): " << is_sequence_labeling_problem(samples,labels)
); );
#ifdef ENABLE_ASSERTS
for (unsigned long i = 0; i < labels.size(); ++i)
{
for (unsigned long j = 0; j < labels[i].size(); ++j)
{
// make sure requires clause is not broken
DLIB_ASSERT(labels[i][j] < trainer.num_labels(),
"\t matrix cross_validate_sequence_labeler()"
<< "\n\t The labels are invalid."
<< "\n\t labels[i][j]: " << labels[i][j]
<< "\n\t trainer.num_labels(): " << trainer.num_labels()
<< "\n\t i: " << i
<< "\n\t j: " << j
);
}
}
#endif
const long num_in_test = samples.size()/folds; const long num_in_test = samples.size()/folds;
...@@ -117,36 +136,7 @@ namespace dlib ...@@ -117,36 +136,7 @@ namespace dlib
} }
matrix<double> temp = test_sequence_labeler(trainer.train(x_train,y_train), x_test, y_test); res += test_sequence_labeler(trainer.train(x_train,y_train), x_test, y_test);
// Make sure res is always at least as big as temp. This might not be the case
// because temp is sized differently depending on how many different kinds of labels
// test_sequence_labeler() sees.
if (get_rect(res).contains(get_rect(temp)) == false)
{
if (res.size() == 0)
{
res.set_size(temp.nr(), temp.nc());
res = 0;
}
// Make res bigger by padding with zeros on the bottom or right if necessary.
if (res.nr() < temp.nr())
res = join_cols(res, zeros_matrix<double>(temp.nr()-res.nc(), res.nc()));
if (res.nc() < temp.nc())
res = join_rows(res, zeros_matrix<double>(res.nr(), temp.nc()-res.nc()));
}
// add temp to res
for (long r = 0; r < temp.nr(); ++r)
{
for (long c = 0; c < temp.nc(); ++c)
{
res(r,c) += temp(r,c);
}
}
} // for (long i = 0; i < folds; ++i) } // for (long i = 0; i < folds; ++i)
......
...@@ -25,14 +25,18 @@ namespace dlib ...@@ -25,14 +25,18 @@ namespace dlib
/*! /*!
requires requires
- is_sequence_labeling_problem(samples, labels) - is_sequence_labeling_problem(samples, labels)
- sequence_labeler_type == dlib::sequence_labeler or an object with a
compatible interface.
ensures ensures
- Tests labeler against the given samples and labels and returns a confusion - Tests labeler against the given samples and labels and returns a confusion
matrix summarizing the results. matrix summarizing the results.
- The confusion matrix C returned by this function has the following properties. - The confusion matrix C returned by this function has the following properties.
- C.nc() == labeler.num_labels() - C.nc() == labeler.num_labels()
- C.nr() == max(labeler.num_labels(), max value in labels + 1) - C.nr() == labeler.num_labels()
- C(T,P) == the number of times a sample with label T was predicted - C(T,P) == the number of times a sample with label T was predicted
to have a label of P. to have a label of P.
- Any samples with a label value >= labeler.num_labels() are ignored. That
is, samples with labels the labeler hasn't ever seen before are ignored.
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -51,6 +55,22 @@ namespace dlib ...@@ -51,6 +55,22 @@ namespace dlib
requires requires
- is_sequence_labeling_problem(samples, labels) - is_sequence_labeling_problem(samples, labels)
- 1 < folds <= samples.size() - 1 < folds <= samples.size()
- for all valid i and j: labels[i][j] < trainer.num_labels()
- trainer_type == dlib::structural_sequence_labeling_trainer or an object
with a compatible interface.
ensures
- performs k-fold cross validation by using the given trainer to solve the
given sequence labeling problem for the given number of folds. Each fold
is tested using the output of the trainer and the confusion matrix from all
folds is summed and returned.
- The total confusion matrix is computed by running test_sequence_labeler()
on each fold and summing its output.
- The number of folds used is given by the folds argument.
- The confusion matrix C returned by this function has the following properties.
- C.nc() == trainer.num_labels()
- C.nr() == trainer.num_labels()
- C(T,P) == the number of times a sample with label T was predicted
to have a label of P.
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -112,7 +112,10 @@ namespace dlib ...@@ -112,7 +112,10 @@ namespace dlib
public: public:
sequence_labeler() sequence_labeler()
{} {
weights.set_size(fe.num_features());
weights = 0;
}
sequence_labeler( sequence_labeler(
const feature_extractor& fe_, const feature_extractor& fe_,
...@@ -120,7 +123,16 @@ namespace dlib ...@@ -120,7 +123,16 @@ namespace dlib
) : ) :
fe(fe_), fe(fe_),
weights(weights_) weights(weights_)
{} {
// make sure requires clause is not broken
DLIB_ASSERT(fe_.num_features() == weights_.size(),
"\t sequence_labeler::sequence_labeler()"
<< "\n\t These sizes should match"
<< "\n\t fe_.num_features(): " << fe_.num_features()
<< "\n\t weights_.size(): " << weights_.size()
<< "\n\t this: " << this
);
}
const feature_extractor& get_feature_extractor ( const feature_extractor& get_feature_extractor (
) const { return fe; } ) const { return fe; }
...@@ -135,6 +147,13 @@ namespace dlib ...@@ -135,6 +147,13 @@ namespace dlib
const sample_sequence_type& x const sample_sequence_type& x
) const ) const
{ {
// make sure requires clause is not broken
DLIB_ASSERT(num_labels() > 0,
"\t labeled_sequence_type sequence_labeler::operator()(x)"
<< "\n\t You can't have no labels."
<< "\n\t this: " << this
);
labeled_sequence_type y; labeled_sequence_type y;
find_max_factor_graph_viterbi(map_prob(x,fe,weights), y); find_max_factor_graph_viterbi(map_prob(x,fe,weights), y);
return y; return y;
...@@ -145,6 +164,13 @@ namespace dlib ...@@ -145,6 +164,13 @@ namespace dlib
labeled_sequence_type& y labeled_sequence_type& y
) const ) const
{ {
// make sure requires clause is not broken
DLIB_ASSERT(num_labels() > 0,
"\t void sequence_labeler::label_sequence(x,y)"
<< "\n\t You can't have no labels."
<< "\n\t this: " << this
);
find_max_factor_graph_viterbi(map_prob(x,fe,weights), y); find_max_factor_graph_viterbi(map_prob(x,fe,weights), y);
} }
......
...@@ -34,12 +34,49 @@ namespace dlib ...@@ -34,12 +34,49 @@ namespace dlib
structural_sequence_labeling_trainer ( structural_sequence_labeling_trainer (
) {} ) {}
const feature_extractor& get_feature_extractor (
) const { return fe; }
unsigned long num_labels (
) const { return fe.num_labels(); }
const sequence_labeler<feature_extractor> train( const sequence_labeler<feature_extractor> train(
const std::vector<sample_sequence_type>& x, const std::vector<sample_sequence_type>& x,
const std::vector<labeled_sequence_type>& y const std::vector<labeled_sequence_type>& y
) const ) const
{ {
// make sure requires clause is not broken
DLIB_ASSERT(is_sequence_labeling_problem(x,y) == true,
"\t sequence_labeler structural_sequence_labeling_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.size(): " << x.size()
<< "\n\t is_sequence_labeling_problem(x,y): " << is_sequence_labeling_problem(x,y)
);
#ifdef ENABLE_ASSERTS
for (unsigned long i = 0; i < y.size(); ++i)
{
for (unsigned long j = 0; j < y[i].size(); ++j)
{
// make sure requires clause is not broken
DLIB_ASSERT(y[i][j] < num_labels(),
"\t sequence_labeler structural_sequence_labeling_trainer::train(x,y)"
<< "\n\t The given labels in y are invalid."
<< "\n\t y[i][j]: " << y[i][j]
<< "\n\t num_labels(): " << num_labels()
<< "\n\t i: " << i
<< "\n\t j: " << j
<< "\n\t this: " << this
);
}
}
#endif
structural_svm_sequence_labeling_problem<feature_extractor> prob(x, y, fe); structural_svm_sequence_labeling_problem<feature_extractor> prob(x, y, fe);
oca solver; oca solver;
matrix<double,0,1> weights; matrix<double,0,1> weights;
......
...@@ -37,10 +37,40 @@ namespace dlib ...@@ -37,10 +37,40 @@ namespace dlib
structural_sequence_labeling_trainer ( structural_sequence_labeling_trainer (
) {} ) {}
const feature_extractor& get_feature_extractor (
) const { return fe; }
/*!
ensures
- returns the feature extractor used by this object
!*/
unsigned long num_labels (
) const { return fe.num_labels(); }
/*!
ensures
- returns get_feature_extractor().num_labels()
(i.e. returns the number of possible output labels for each
element of a sequence)
!*/
const sequence_labeler<feature_extractor> train( const sequence_labeler<feature_extractor> train(
const std::vector<sample_sequence_type>& x, const std::vector<sample_sequence_type>& x,
const std::vector<labeled_sequence_type>& y const std::vector<labeled_sequence_type>& y
) const; ) const;
/*!
requires
- is_sequence_labeling_problem(x, y)
- for all valid i and j: y[i][j] < num_labels()
ensures
- Uses the structural_svm_sequence_labeling_problem to train a
sequence_labeler on the given x/y training pairs. The idea is
to learn to predict a y given an input x.
- returns a function F with the following properties:
- F(new_x) == A sequence of predicted labels for the elements of new_x.
- F(new_x).size() == new_x.size()
- for all valid i:
- F(new_x)[i] == the predicted label of new_x[i]
!*/
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment