Commit fdc3af3a authored by Davis King

Refactoring and spec improvement. Still some work left to do though.

parent 7dc67b80
......@@ -25,7 +25,7 @@ namespace dlib
)
{
// make sure requires clause is not broken
DLIB_CASSERT( is_sequence_labeling_problem(samples, labels) == true,
DLIB_ASSERT( is_sequence_labeling_problem(samples, labels) == true,
"\tmatrix test_sequence_labeler()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t is_sequence_labeling_problem(samples, labels): "
......@@ -44,8 +44,8 @@ namespace dlib
const unsigned long truth = labels[i][j];
if (truth >= res.nr())
{
// make res big enough for this unexpected label
res = join_cols(res, zeros_matrix<double>(truth-res.nr()+1, res.nc()));
// ignore labels the labeler doesn't know about.
continue;
}
res(truth, pred[j]) += 1;
......@@ -69,7 +69,7 @@ namespace dlib
)
{
// make sure requires clause is not broken
DLIB_CASSERT(is_sequence_labeling_problem(samples,labels) == true &&
DLIB_ASSERT(is_sequence_labeling_problem(samples,labels) == true &&
1 < folds && folds <= static_cast<long>(samples.size()),
"\tmatrix cross_validate_sequence_labeler()"
<< "\n\t invalid inputs were given to this function"
......@@ -78,6 +78,25 @@ namespace dlib
<< "\n\t is_sequence_labeling_problem(samples,labels): " << is_sequence_labeling_problem(samples,labels)
);
#ifdef ENABLE_ASSERTS
for (unsigned long i = 0; i < labels.size(); ++i)
{
for (unsigned long j = 0; j < labels[i].size(); ++j)
{
// make sure requires clause is not broken
DLIB_ASSERT(labels[i][j] < trainer.num_labels(),
"\t matrix cross_validate_sequence_labeler()"
<< "\n\t The labels are invalid."
<< "\n\t labels[i][j]: " << labels[i][j]
<< "\n\t trainer.num_labels(): " << trainer.num_labels()
<< "\n\t i: " << i
<< "\n\t j: " << j
);
}
}
#endif
const long num_in_test = samples.size()/folds;
......@@ -117,36 +136,7 @@ namespace dlib
}
matrix<double> temp = test_sequence_labeler(trainer.train(x_train,y_train), x_test, y_test);
// Make sure res is always at least as big as temp. This might not be the case
// because temp is sized differently depending on how many different kinds of labels
// test_sequence_labeler() sees.
if (get_rect(res).contains(get_rect(temp)) == false)
{
if (res.size() == 0)
{
res.set_size(temp.nr(), temp.nc());
res = 0;
}
// Make res bigger by padding with zeros on the bottom or right if necessary.
if (res.nr() < temp.nr())
res = join_cols(res, zeros_matrix<double>(temp.nr()-res.nc(), res.nc()));
if (res.nc() < temp.nc())
res = join_rows(res, zeros_matrix<double>(res.nr(), temp.nc()-res.nc()));
}
// add temp to res
for (long r = 0; r < temp.nr(); ++r)
{
for (long c = 0; c < temp.nc(); ++c)
{
res(r,c) += temp(r,c);
}
}
res += test_sequence_labeler(trainer.train(x_train,y_train), x_test, y_test);
} // for (long i = 0; i < folds; ++i)
......
......@@ -25,14 +25,18 @@ namespace dlib
/*!
requires
- is_sequence_labeling_problem(samples, labels)
- sequence_labeler_type == dlib::sequence_labeler or an object with a
compatible interface.
ensures
- Tests labeler against the given samples and labels and returns a confusion
matrix summarizing the results.
- The confusion matrix C returned by this function has the following properties.
- C.nc() == labeler.num_labels()
- C.nr() == max(labeler.num_labels(), max value in labels + 1)
- C.nr() == labeler.num_labels()
- C(T,P) == the number of times a sample with label T was predicted
to have a label of P.
- Any samples with a label value >= labeler.num_labels() are ignored. That
is, samples with labels the labeler hasn't ever seen before are ignored.
!*/
// ----------------------------------------------------------------------------------------
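For reference, here is a minimal usage sketch of test_sequence_labeler() as specified above. It is not part of this commit: the function name report_accuracy and the include path are illustrative assumptions, and it presumes a labeler already trained with some user-defined feature extractor plus a held-out sample/label set.

// Minimal sketch (assumptions noted above).  The sequence labeling tools are
// assumed to be reachable through <dlib/svm_threaded.h>.
#include <dlib/svm_threaded.h>
#include <iostream>
#include <vector>

template <typename labeler_type, typename sequence_type>
void report_accuracy (
    const labeler_type& labeler,
    const std::vector<sequence_type>& samples,
    const std::vector<std::vector<unsigned long> >& labels
)
{
    // C has labeler.num_labels() rows and columns; C(T,P) counts elements
    // with true label T that were predicted to have label P.
    const dlib::matrix<double> C = dlib::test_sequence_labeler(labeler, samples, labels);

    // The diagonal holds the correct predictions, so this is the overall accuracy.
    std::cout << "fraction labeled correctly: "
              << dlib::sum(dlib::diag(C))/dlib::sum(C) << std::endl;
}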
......@@ -51,6 +55,22 @@ namespace dlib
requires
- is_sequence_labeling_problem(samples, labels)
- 1 < folds <= samples.size()
- for all valid i and j: labels[i][j] < trainer.num_labels()
- trainer_type == dlib::structural_sequence_labeling_trainer or an object
with a compatible interface.
ensures
- performs k-fold cross validation by using the given trainer to solve the
given sequence labeling problem for the given number of folds. Each fold
is tested using the output of the trainer, and the confusion matrices from
all folds are summed and returned.
- The total confusion matrix is computed by running test_sequence_labeler()
on each fold and summing its output.
- The number of folds used is given by the folds argument.
- The confusion matrix C returned by this function has the following properties.
- C.nc() == trainer.num_labels()
- C.nr() == trainer.num_labels()
- C(T,P) == the number of times a sample with label T was predicted
to have a label of P.
!*/
// ----------------------------------------------------------------------------------------
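Similarly, a hedged sketch of calling cross_validate_sequence_labeler() as specified above (reusing the includes from the previous sketch); the 4-fold choice and the function name print_cv_results are illustrative assumptions.

template <typename trainer_type, typename sequence_type>
void print_cv_results (
    const trainer_type& trainer,
    const std::vector<sequence_type>& samples,
    const std::vector<std::vector<unsigned long> >& labels
)
{
    // Using 4 folds requires samples.size() >= 4 (see the
    // 1 < folds <= samples.size() precondition above).  The result is the sum
    // of the confusion matrices test_sequence_labeler() produces on each fold.
    const dlib::matrix<double> C =
        dlib::cross_validate_sequence_labeler(trainer, samples, labels, 4);
    std::cout << "summed confusion matrix:\n" << C << std::endl;
}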
......
......@@ -112,7 +112,10 @@ namespace dlib
public:
sequence_labeler()
{}
{
weights.set_size(fe.num_features());
weights = 0;
}
sequence_labeler(
const feature_extractor& fe_,
......@@ -120,7 +123,16 @@ namespace dlib
) :
fe(fe_),
weights(weights_)
{}
{
// make sure requires clause is not broken
DLIB_ASSERT(fe_.num_features() == weights_.size(),
"\t sequence_labeler::sequence_labeler()"
<< "\n\t These sizes should match"
<< "\n\t fe_.num_features(): " << fe_.num_features()
<< "\n\t weights_.size(): " << weights_.size()
<< "\n\t this: " << this
);
}
const feature_extractor& get_feature_extractor (
) const { return fe; }
......@@ -135,6 +147,13 @@ namespace dlib
const sample_sequence_type& x
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(num_labels() > 0,
"\t labeled_sequence_type sequence_labeler::operator()(x)"
<< "\n\t You can't have no labels."
<< "\n\t this: " << this
);
labeled_sequence_type y;
find_max_factor_graph_viterbi(map_prob(x,fe,weights), y);
return y;
......@@ -145,6 +164,13 @@ namespace dlib
labeled_sequence_type& y
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(num_labels() > 0,
"\t void sequence_labeler::label_sequence(x,y)"
<< "\n\t You can't have no labels."
<< "\n\t this: " << this
);
find_max_factor_graph_viterbi(map_prob(x,fe,weights), y);
}
......
......@@ -34,12 +34,49 @@ namespace dlib
structural_sequence_labeling_trainer (
) {}
const feature_extractor& get_feature_extractor (
) const { return fe; }
unsigned long num_labels (
) const { return fe.num_labels(); }
const sequence_labeler<feature_extractor> train(
const std::vector<sample_sequence_type>& x,
const std::vector<labeled_sequence_type>& y
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(is_sequence_labeling_problem(x,y) == true,
"\t sequence_labeler structural_sequence_labeling_trainer::train(x,y)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t x.size(): " << x.size()
<< "\n\t is_sequence_labeling_problem(x,y): " << is_sequence_labeling_problem(x,y)
);
#ifdef ENABLE_ASSERTS
for (unsigned long i = 0; i < y.size(); ++i)
{
for (unsigned long j = 0; j < y[i].size(); ++j)
{
// make sure requires clause is not broken
DLIB_ASSERT(y[i][j] < num_labels(),
"\t sequence_labeler structural_sequence_labeling_trainer::train(x,y)"
<< "\n\t The given labels in y are invalid."
<< "\n\t y[i][j]: " << y[i][j]
<< "\n\t num_labels(): " << num_labels()
<< "\n\t i: " << i
<< "\n\t j: " << j
<< "\n\t this: " << this
);
}
}
#endif
structural_svm_sequence_labeling_problem<feature_extractor> prob(x, y, fe);
oca solver;
matrix<double,0,1> weights;
......
......@@ -37,10 +37,40 @@ namespace dlib
structural_sequence_labeling_trainer (
) {}
const feature_extractor& get_feature_extractor (
) const { return fe; }
/*!
ensures
- returns the feature extractor used by this object
!*/
unsigned long num_labels (
) const { return fe.num_labels(); }
/*!
ensures
- returns get_feature_extractor().num_labels()
(i.e. returns the number of possible output labels for each
element of a sequence)
!*/
const sequence_labeler<feature_extractor> train(
const std::vector<sample_sequence_type>& x,
const std::vector<labeled_sequence_type>& y
) const;
/*!
requires
- is_sequence_labeling_problem(x, y)
- for all valid i and j: y[i][j] < num_labels()
ensures
- Uses the structural_svm_sequence_labeling_problem to train a
sequence_labeler on the given x/y training pairs. The idea is
to learn to predict a y given an input x.
- returns a function F with the following properties:
- F(new_x) == A sequence of predicted labels for the elements of new_x.
- F(new_x).size() == new_x.size()
- for all valid i:
- F(new_x)[i] == the predicted label of new_x[i]
!*/
};
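To tie the pieces together, here is a hedged end-to-end sketch of the train/predict cycle documented above. The feature extractor type (called my_fe purely for illustration) is assumed to be a default-constructable user-defined type meeting the feature extractor requirements, and labeled_sequence_type is assumed to be std::vector<unsigned long>, matching the label containers used throughout this file.

template <typename my_fe, typename sequence_type>
std::vector<unsigned long> train_and_label (
    const std::vector<sequence_type>& x,
    const std::vector<std::vector<unsigned long> >& y,
    const sequence_type& new_x
)
{
    // Per the spec above: is_sequence_labeling_problem(x,y) must hold and
    // every y[i][j] must be < trainer.num_labels().
    dlib::structural_sequence_labeling_trainer<my_fe> trainer;

    // Learn a sequence_labeler F from the x/y training pairs.
    const dlib::sequence_labeler<my_fe> labeler = trainer.train(x, y);

    // F(new_x): one predicted label per element of new_x.
    return labeler(new_x);
}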
......