Commit 1efcfb3d authored by Davis King

Made the sequence_segmenter work with both BIO and BILOU tagging models.

parent 6d623eeb
...@@ -16,9 +16,13 @@ namespace dlib ...@@ -16,9 +16,13 @@ namespace dlib
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
// BIO/BILOU labels
const unsigned int BEGIN = 0; const unsigned int BEGIN = 0;
const unsigned int INSIDE = 1; const unsigned int INSIDE = 1;
const unsigned int OUTSIDE = 2; const unsigned int OUTSIDE = 2;
const unsigned int LAST = 3;
const unsigned int UNIT = 4;
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
...@@ -52,13 +56,11 @@ namespace dlib ...@@ -52,13 +56,11 @@ namespace dlib
unsigned long num_features() const unsigned long num_features() const
{ {
const unsigned long NL = ss_feature_extractor::use_BIO_model ? 3 : 5;
const int base_dims = fe.num_features(); if (ss_feature_extractor::use_high_order_features)
return num_labels()*( return NL*NL + (NL*NL+NL)*fe.num_features()*fe.window_size();
num_labels() + // previous and current label else
base_dims*fe.window_size() + // window around current element return NL*NL + NL*fe.num_features()*fe.window_size();
num_labels()*base_dims*fe.window_size() // window around current element in conjunction with previous label
);
} }
unsigned long order() const unsigned long order() const
...@@ -68,7 +70,10 @@ namespace dlib ...@@ -68,7 +70,10 @@ namespace dlib
unsigned long num_labels() const unsigned long num_labels() const
{ {
return 3; if (ss_feature_extractor::use_BIO_model)
return 3;
else
return 5;
} }
private: private:
...@@ -113,10 +118,58 @@ namespace dlib ...@@ -113,10 +118,58 @@ namespace dlib
unsigned long unsigned long
) const ) const
{ {
// Don't allow BIO label patterns that don't correspond to a sensical if (ss_feature_extractor::use_BIO_model)
// segmentation. {
if (y.size() > 1 && y(0) == INSIDE && y(1) == OUTSIDE) // Don't allow BIO label patterns that don't correspond to a sensical
return true; // segmentation.
if (y.size() > 1 && y(0) == INSIDE && y(1) == OUTSIDE)
return true;
if (y.size() == 1 && y(0) == INSIDE)
return true;
}
else
{
// Don't allow BILOU label patterns that don't correspond to a sensical
// segmentation.
if (y.size() > 1)
{
if (y(1) == BEGIN && y(0) == OUTSIDE)
return true;
if (y(1) == BEGIN && y(0) == UNIT)
return true;
if (y(1) == BEGIN && y(0) == BEGIN)
return true;
if (y(1) == INSIDE && y(0) == BEGIN)
return true;
if (y(1) == INSIDE && y(0) == OUTSIDE)
return true;
if (y(1) == INSIDE && y(0) == UNIT)
return true;
if (y(1) == OUTSIDE && y(0) == INSIDE)
return true;
if (y(1) == OUTSIDE && y(0) == LAST)
return true;
if (y(1) == LAST && y(0) == INSIDE)
return true;
if (y(1) == LAST && y(0) == LAST)
return true;
if (y(1) == UNIT && y(0) == INSIDE)
return true;
if (y(1) == UNIT && y(0) == LAST)
return true;
}
else
{
if (y(0) == INSIDE)
return true;
if (y(0) == LAST)
return true;
}
}
return false; return false;
} }
...@@ -146,7 +199,8 @@ namespace dlib ...@@ -146,7 +199,8 @@ namespace dlib
const unsigned long off1 = y(0)*base_dims; const unsigned long off1 = y(0)*base_dims;
dot_functor<feature_setter> fs1(set_feature, offset+off1); dot_functor<feature_setter> fs1(set_feature, offset+off1);
fe.get_features(fs1, x, pos); fe.get_features(fs1, x, pos);
if (y.size() > 1)
if (ss_feature_extractor::use_high_order_features && y.size() > 1)
{ {
const unsigned long off2 = num_labels()*base_dims + (y(0)*num_labels()+y(1))*base_dims; const unsigned long off2 = num_labels()*base_dims + (y(0)*num_labels()+y(1))*base_dims;
dot_functor<feature_setter> fs2(set_feature, offset+off2); dot_functor<feature_setter> fs2(set_feature, offset+off2);
...@@ -154,7 +208,10 @@ namespace dlib ...@@ -154,7 +208,10 @@ namespace dlib
} }
} }
offset += num_labels()*(base_dims + num_labels()*base_dims); if (ss_feature_extractor::use_high_order_features)
offset += num_labels()*base_dims + num_labels()*num_labels()*base_dims;
else
offset += num_labels()*base_dims;
} }
} }
...@@ -171,7 +228,11 @@ namespace dlib ...@@ -171,7 +228,11 @@ namespace dlib
const feature_extractor& fe const feature_extractor& fe
) )
{ {
return 3*3 + 12*fe.num_features()*fe.window_size(); const unsigned long NL = feature_extractor::use_BIO_model ? 3 : 5;
if (feature_extractor::use_high_order_features)
return NL*NL + (NL*NL+NL)*fe.num_features()*fe.window_size();
else
return NL*NL + NL*fe.num_features()*fe.window_size();
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -272,18 +333,41 @@ namespace dlib ...@@ -272,18 +333,41 @@ namespace dlib
std::vector<unsigned long> labels; std::vector<unsigned long> labels;
labeler.label_sequence(x, labels); labeler.label_sequence(x, labels);
// Convert from BIO tagging to the explicit segments representation. if (feature_extractor::use_BIO_model)
for (unsigned long i = 0; i < labels.size(); ++i)
{ {
if (labels[i] == impl_ss::BEGIN) // Convert from BIO tagging to the explicit segments representation.
for (unsigned long i = 0; i < labels.size(); ++i)
{ {
const unsigned long begin = i; if (labels[i] == impl_ss::BEGIN)
++i; {
while (i < labels.size() && labels[i] == impl_ss::INSIDE) const unsigned long begin = i;
++i;
while (i < labels.size() && labels[i] == impl_ss::INSIDE)
++i;
y.push_back(std::make_pair(begin, i));
--i;
}
}
}
else
{
// Convert from BILOU tagging to the explicit segments representation.
for (unsigned long i = 0; i < labels.size(); ++i)
{
if (labels[i] == impl_ss::BEGIN)
{
const unsigned long begin = i;
++i; ++i;
while (i < labels.size() && labels[i] == impl_ss::INSIDE)
++i;
y.push_back(std::make_pair(begin, i)); y.push_back(std::make_pair(begin, i+1));
--i; }
else if (labels[i] == impl_ss::UNIT)
{
y.push_back(std::make_pair(i, i+1));
}
} }
} }
} }
......
...@@ -26,25 +26,30 @@ namespace dlib ...@@ -26,25 +26,30 @@ namespace dlib
Where w is a parameter vector and the label sequence defines a segmentation Where w is a parameter vector and the label sequence defines a segmentation
of x. of x.
Recall that a sequence_segmenter uses the BIO tagging model and is also an Recall that a sequence_segmenter uses the BIO or BILOU tagging model and is
instantiation of the dlib::sequence_labeler. This means that each element also an instantiation of the dlib::sequence_labeler. Selecting to use the
of the label sequence y takes on one of three possible values (B, I, or O) BIO model means that each element of the label sequence y takes on one of
and together these labels define a segmentation of the sequence. For example, three possible values (B, I, or O) and together these labels define a
to represent a segmentation of the sequence of words "The dog ran to Bob Jones" segmentation of the sequence. For example, to represent a segmentation of
where only "Bob Jones" was segmented out we would use the label sequence OOOOBI. the sequence of words "The dog ran to Bob Jones" where only "Bob Jones" was
segmented out we would use the label sequence OOOOBI. The BILOU model is
similar except that it uses five different labels and each segment is
labeled as U, BL, BIL, BIIL, BIIIL, and so on depending on its length.
Therefore, the BILOU model is able to more explicitly model the ends of the
segments than the BIO model, but has more parameters to estimate.
Keeping this in mind, the purpose of a sequence_segmenter is to take care Keeping all this in mind, the purpose of a sequence_segmenter is to take
of the bookkeeping associated with creating BIO tagging models for care of the bookkeeping associated with creating BIO/BILOU tagging models
segmentation tasks. In particular, it presents the user with a simplified for segmentation tasks. In particular, it presents the user with a
version of the interface used by the dlib::sequence_labeler. It does this simplified version of the interface used by the dlib::sequence_labeler. It
by completely hiding the BIO tags from the user and instead exposes an does this by completely hiding the BIO/BILOU tags from the user and instead
explicit sub-segment based labeling representation. It also simplifies the exposes an explicit sub-segment based labeling representation. It also
construction of the PSI() feature vector. simplifies the construction of the PSI() feature vector.
Like in the dlib::sequence_labeler, PSI() is a sum of feature vectors, each Like in the dlib::sequence_labeler, PSI() is a sum of feature vectors, each
derived from the entire input sequence x but only part of the label derived from the entire input sequence x but only part of the label
sequence y. In the case of the sequence_segmenter, we use an order one sequence y. In the case of the sequence_segmenter, we use an order one
Markov model. This means that that Markov model. This means that
PSI(x,y) == sum_i XI(x, y_{i-1}, y_{i}, i) PSI(x,y) == sum_i XI(x, y_{i-1}, y_{i}, i)
where the sum is taken over all the elements in the sequence. At each where the sum is taken over all the elements in the sequence. At each
element we extract a feature vector, XI(), that is expected to encode element we extract a feature vector, XI(), that is expected to encode
...@@ -61,12 +66,12 @@ namespace dlib ...@@ -61,12 +66,12 @@ namespace dlib
independent of any labeling. We denote this feature vector by ZI(x,i), where independent of any labeling. We denote this feature vector by ZI(x,i), where
x is the sequence and i is the position in question. x is the sequence and i is the position in question.
For example, suppose we use a window size of 3, then we can put all this For example, suppose we use a window size of 3 and BIO tags, then we can
together and define XI() in terms of ZI(). To do this, we can think of put all this together and define XI() in terms of ZI(). To do this, we can
XI() as containing 12*3 slots which contain either a zero vector or a ZI() think of XI() as containing 12*3 slots which contain either a zero vector
vector. Each combination of window position and labeling has a different or a ZI() vector. Each combination of window position and labeling has a
slot. To explain further, consider the following examples where we have different slot. To explain further, consider the following examples where
annotated which parts of XI() correspond to each slot. we have annotated which parts of XI() correspond to each slot.
If the previous and current label are both B and we use a window size of 3 If the previous and current label are both B and we use a window size of 3
then XI() would be instantiated as: then XI() would be instantiated as:
...@@ -152,7 +157,10 @@ namespace dlib ...@@ -152,7 +157,10 @@ namespace dlib
0 \ 0 \
0 > If previous label is O and current label is O 0 > If previous label is O and current label is O
0] / 0] /
If we had instead used the BILOU tagging model, the XI() vector would
have been similarly defined except that there would be 5*5+5 slots for
the various label combinations instead of 3*3+3.
Finally, while not shown here, we also include nine indicator features Finally, while not shown here, we also include nine indicator features
(25 for the BILOU model) in XI() to model label transitions. (25 for the BILOU model) in XI() to model label transitions.
...@@ -168,6 +176,19 @@ namespace dlib ...@@ -168,6 +176,19 @@ namespace dlib
// anything so long as it has a .size() which returns the length of the sequence. // anything so long as it has a .size() which returns the length of the sequence.
typedef the_type_used_to_represent_a_sequence sequence_type; typedef the_type_used_to_represent_a_sequence sequence_type;
// If you want to use the BIO tagging model then set this bool to true. Set it to
// false to use the BILOU tagging model.
const static bool use_BIO_model = true;
// In the WHAT THIS OBJECT REPRESENTS section above we discussed how we model the
// conjunction of the previous label and the window around each position. Doing
// this greatly expands the size of the parameter vector w. You can optionally
// disable these higher order features by setting the use_high_order_features bool
// to false. This will cause XI() to include only slots which are independent of
// the previous label.
const static bool use_high_order_features = true;
example_feature_extractor ( example_feature_extractor (
); );
/*! /*!
...@@ -257,9 +278,8 @@ namespace dlib ...@@ -257,9 +278,8 @@ namespace dlib
- fe must be an object that implements an interface compatible with the - fe must be an object that implements an interface compatible with the
example_feature_extractor discussed above. example_feature_extractor discussed above.
ensures ensures
- returns 3*3 + 12*fe.num_features()*fe.window_size() - returns the dimensionality of the PSI() vector defined by the given feature
(i.e. returns the dimensionality of the PSI() vector defined by the given extractor.
feature extractor.
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -283,10 +303,11 @@ namespace dlib ...@@ -283,10 +303,11 @@ namespace dlib
contiguous words which refer to proper names. contiguous words which refer to proper names.
The sequence_segmenter is implemented using the BIO (Begin, Inside, The sequence_segmenter is implemented using the BIO (Begin, Inside,
Outside) sequence tagging model. Moreover, the sequence tagging is done Outside) or BILOU (Begin, Inside, Last, Outside, Unit) sequence tagging
internally using a dlib::sequence_labeler object and therefore model. Moreover, the sequence tagging is done internally using a
sequence_segmenter objects are examples of chain structured conditional dlib::sequence_labeler object and therefore sequence_segmenter objects are
random field style sequence taggers. examples of chain structured conditional random field style sequence
taggers.
THREAD SAFETY THREAD SAFETY
It is always safe to use distinct instances of this object in different It is always safe to use distinct instances of this object in different
......
...@@ -142,20 +142,50 @@ namespace dlib ...@@ -142,20 +142,50 @@ namespace dlib
<< "\n\t this: " << this << "\n\t this: " << this
); );
// convert y into tagged BIO labels
std::vector<std::vector<unsigned long> > labels(y.size()); std::vector<std::vector<unsigned long> > labels(y.size());
for (unsigned long i = 0; i < labels.size(); ++i) if (feature_extractor::use_BIO_model)
{ {
labels[i].resize(x[i].size(), impl_ss::OUTSIDE); // convert y into tagged BIO labels
for (unsigned long j = 0; j < y[i].size(); ++j) for (unsigned long i = 0; i < labels.size(); ++i)
{ {
const unsigned long begin = y[i][j].first; labels[i].resize(x[i].size(), impl_ss::OUTSIDE);
const unsigned long end = y[i][j].second; for (unsigned long j = 0; j < y[i].size(); ++j)
if (begin != end)
{ {
labels[i][begin] = impl_ss::BEGIN; const unsigned long begin = y[i][j].first;
for (unsigned long k = begin+1; k < end; ++k) const unsigned long end = y[i][j].second;
labels[i][k] = impl_ss::INSIDE; if (begin != end)
{
labels[i][begin] = impl_ss::BEGIN;
for (unsigned long k = begin+1; k < end; ++k)
labels[i][k] = impl_ss::INSIDE;
}
}
}
}
else
{
// convert y into tagged BILOU labels
for (unsigned long i = 0; i < labels.size(); ++i)
{
labels[i].resize(x[i].size(), impl_ss::OUTSIDE);
for (unsigned long j = 0; j < y[i].size(); ++j)
{
const unsigned long begin = y[i][j].first;
const unsigned long end = y[i][j].second;
if (begin != end)
{
if (begin+1==end)
{
labels[i][begin] = impl_ss::UNIT;
}
else
{
labels[i][begin] = impl_ss::BEGIN;
for (unsigned long k = begin+1; k+1 < end; ++k)
labels[i][k] = impl_ss::INSIDE;
labels[i][end-1] = impl_ss::LAST;
}
}
} }
} }
} }
......
...@@ -20,10 +20,14 @@ namespace ...@@ -20,10 +20,14 @@ namespace
dlib::rand rnd; dlib::rand rnd;
template <bool use_BIO_model_, bool use_high_order_features_>
class unigram_extractor class unigram_extractor
{ {
public: public:
const static bool use_BIO_model = use_BIO_model_;
const static bool use_high_order_features = use_high_order_features_;
typedef std::vector<unsigned long> sequence_type; typedef std::vector<unsigned long> sequence_type;
std::map<unsigned long, matrix<double,0,1> > feats; std::map<unsigned long, matrix<double,0,1> > feats;
...@@ -64,12 +68,14 @@ namespace ...@@ -64,12 +68,14 @@ namespace
}; };
void serialize(const unigram_extractor& item , std::ostream& out ) template <bool use_BIO_model_, bool use_high_order_features_>
void serialize(const unigram_extractor<use_BIO_model_,use_high_order_features_>& item , std::ostream& out )
{ {
serialize(item.feats, out); serialize(item.feats, out);
} }
void deserialize(unigram_extractor& item, std::istream& in) template <bool use_BIO_model_, bool use_high_order_features_>
void deserialize(unigram_extractor<use_BIO_model_,use_high_order_features_>& item, std::istream& in)
{ {
deserialize(item.feats, in); deserialize(item.feats, in);
} }
...@@ -89,7 +95,7 @@ namespace ...@@ -89,7 +95,7 @@ namespace
labels.resize(dataset_size); labels.resize(dataset_size);
unigram_extractor fe; unigram_extractor<true,true> fe;
dlib::rand rnd; dlib::rand rnd;
for (unsigned long iter = 0; iter < dataset_size; ++iter) for (unsigned long iter = 0; iter < dataset_size; ++iter)
...@@ -161,22 +167,27 @@ namespace ...@@ -161,22 +167,27 @@ namespace
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <bool use_BIO_model, bool use_high_order_features>
void do_test() void do_test()
{ {
dlog << LINFO << "use_BIO_model: "<< use_BIO_model;
dlog << LINFO << "use_high_order_features: "<< use_high_order_features;
std::vector<std::vector<unsigned long> > samples; std::vector<std::vector<unsigned long> > samples;
std::vector<std::vector<std::pair<unsigned long,unsigned long> > > segments; std::vector<std::vector<std::pair<unsigned long,unsigned long> > > segments;
make_dataset2( samples, segments, 200); make_dataset2( samples, segments, 200);
print_spinner(); print_spinner();
typedef unigram_extractor<use_BIO_model,use_high_order_features> fe_type;
unigram_extractor fe_temp; fe_type fe_temp;
unigram_extractor fe_temp2; fe_type fe_temp2;
structural_sequence_segmentation_trainer<unigram_extractor> trainer(fe_temp2); structural_sequence_segmentation_trainer<fe_type> trainer(fe_temp2);
trainer.set_c(4); trainer.set_c(4);
trainer.set_num_threads(1); trainer.set_num_threads(1);
sequence_segmenter<unigram_extractor> labeler = trainer.train(samples, segments); sequence_segmenter<fe_type> labeler = trainer.train(samples, segments);
print_spinner(); print_spinner();
...@@ -215,7 +226,7 @@ namespace ...@@ -215,7 +226,7 @@ namespace
ostringstream sout; ostringstream sout;
serialize(labeler, sout); serialize(labeler, sout);
istringstream sin(sout.str()); istringstream sin(sout.str());
sequence_segmenter<unigram_extractor> labeler2; sequence_segmenter<fe_type> labeler2;
deserialize(labeler2, sin); deserialize(labeler2, sin);
res = test_sequence_segmenter(labeler2, samples, segments); res = test_sequence_segmenter(labeler2, samples, segments);
...@@ -238,7 +249,10 @@ namespace ...@@ -238,7 +249,10 @@ namespace
void perform_test ( void perform_test (
) )
{ {
do_test(); do_test<true,true>();
do_test<true,false>();
do_test<false,true>();
do_test<false,false>();
} }
} a; } a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment