Commit d1cf19fc authored by Davis King

Added an option to learn just non-negative weights.

parent 1efcfb3d
...@@ -44,6 +44,22 @@ namespace dlib ...@@ -44,6 +44,22 @@ namespace dlib
feature_extractor() {} feature_extractor() {}
feature_extractor(const ss_feature_extractor& ss_fe_) : fe(ss_fe_) {} feature_extractor(const ss_feature_extractor& ss_fe_) : fe(ss_fe_) {}
// Returns how many weights are to be kept non-negative.  This is 0 when
// the wrapped feature extractor allows negative weights.  Otherwise it is
// every feature except the label transition features.
unsigned long num_nonnegative_weights (
) const
{
    // Nothing is constrained when negative weights are permitted.
    if (ss_feature_extractor::allow_negative_weights)
        return 0;

    // 3 labels when use_BIO_model is set, 5 otherwise.
    const unsigned long label_count = ss_feature_extractor::use_BIO_model ? 3 : 5;

    // The label transition features (label_count*label_count of them) are the
    // only ones exempt from the non-negativity requirement.
    const unsigned long transition_feature_count = label_count*label_count;
    return num_features() - transition_feature_count;
}
friend void serialize(const feature_extractor& item, std::ostream& out) friend void serialize(const feature_extractor& item, std::ostream& out)
{ {
serialize(item.fe, out); serialize(item.fe, out);
...@@ -181,12 +197,7 @@ namespace dlib ...@@ -181,12 +197,7 @@ namespace dlib
unsigned long position unsigned long position
) const ) const
{ {
// Pull out an indicator feature for the type of transition between the unsigned long offset = 0;
// previous label and the current label.
if (y.size() > 1)
set_feature(y(1)*num_labels() + y(0));
unsigned long offset = num_labels()*num_labels();
const int window_size = fe.window_size(); const int window_size = fe.window_size();
...@@ -214,6 +225,10 @@ namespace dlib ...@@ -214,6 +225,10 @@ namespace dlib
offset += num_labels()*base_dims; offset += num_labels()*base_dims;
} }
// Pull out an indicator feature for the type of transition between the
// previous label and the current label.
if (y.size() > 1)
set_feature(offset + y(1)*num_labels() + y(0));
} }
}; };
......
...@@ -162,8 +162,10 @@ namespace dlib ...@@ -162,8 +162,10 @@ namespace dlib
have been similarly defined except that there would be 5*5+5 slots for have been similarly defined except that there would be 5*5+5 slots for
the various label combination instead of 3*3+3. the various label combination instead of 3*3+3.
Finally, while not shown here, we also include nine indicator features Finally, while not shown here, we also include indicator features in
in XI() to model label transitions. XI() to model label transitions. These are 9 extra features in the
case of the BIO tagging model and 25 extra in the case of the BILOU
tagging model.
THREAD SAFETY THREAD SAFETY
Instances of this object are required to be threadsafe, that is, it should Instances of this object are required to be threadsafe, that is, it should
...@@ -188,6 +190,13 @@ namespace dlib ...@@ -188,6 +190,13 @@ namespace dlib
// the previous label. // the previous label.
const static bool use_high_order_features = true; const static bool use_high_order_features = true;
// You use a tool like the structural_sequence_segmentation_trainer to learn the
// weight vector needed by a sequence_segmenter. You can tell the trainer to force
// all the elements of the weight vector corresponding to ZI() to be non-negative.
// These are all the elements of w except for the elements corresponding to the label
// transition indicator features. To do this, just set allow_negative_weights to false.
const static bool allow_negative_weights = true;
example_feature_extractor ( example_feature_extractor (
); );
......
...@@ -20,13 +20,14 @@ namespace ...@@ -20,13 +20,14 @@ namespace
dlib::rand rnd; dlib::rand rnd;
template <bool use_BIO_model_, bool use_high_order_features_> template <bool use_BIO_model_, bool use_high_order_features_, bool allow_negative_weights_>
class unigram_extractor class unigram_extractor
{ {
public: public:
const static bool use_BIO_model = use_BIO_model_; const static bool use_BIO_model = use_BIO_model_;
const static bool use_high_order_features = use_high_order_features_; const static bool use_high_order_features = use_high_order_features_;
const static bool allow_negative_weights = allow_negative_weights_;
typedef std::vector<unsigned long> sequence_type; typedef std::vector<unsigned long> sequence_type;
...@@ -38,6 +39,12 @@ namespace ...@@ -38,6 +39,12 @@ namespace
v1 = randm(num_features(), 1, rnd); v1 = randm(num_features(), 1, rnd);
v2 = randm(num_features(), 1, rnd); v2 = randm(num_features(), 1, rnd);
v3 = randm(num_features(), 1, rnd); v3 = randm(num_features(), 1, rnd);
v1(0) = 1;
v2(1) = 1;
v3(2) = 1;
v1(3) = -1;
v2(4) = -1;
v3(5) = -1;
for (unsigned long i = 0; i < num_features(); ++i) for (unsigned long i = 0; i < num_features(); ++i)
{ {
if ( i < 3) if ( i < 3)
...@@ -68,14 +75,14 @@ namespace ...@@ -68,14 +75,14 @@ namespace
}; };
template <bool use_BIO_model_, bool use_high_order_features_> template <bool use_BIO_model_, bool use_high_order_features_, bool neg>
void serialize(const unigram_extractor<use_BIO_model_,use_high_order_features_>& item , std::ostream& out ) void serialize(const unigram_extractor<use_BIO_model_,use_high_order_features_,neg>& item , std::ostream& out )
{ {
serialize(item.feats, out); serialize(item.feats, out);
} }
template <bool use_BIO_model_, bool use_high_order_features_> template <bool use_BIO_model_, bool use_high_order_features_, bool neg>
void deserialize(unigram_extractor<use_BIO_model_,use_high_order_features_>& item, std::istream& in) void deserialize(unigram_extractor<use_BIO_model_,use_high_order_features_,neg>& item, std::istream& in)
{ {
deserialize(item.feats, in); deserialize(item.feats, in);
} }
...@@ -95,7 +102,7 @@ namespace ...@@ -95,7 +102,7 @@ namespace
labels.resize(dataset_size); labels.resize(dataset_size);
unigram_extractor<true,true> fe; unigram_extractor<true,true,true> fe;
dlib::rand rnd; dlib::rand rnd;
for (unsigned long iter = 0; iter < dataset_size; ++iter) for (unsigned long iter = 0; iter < dataset_size; ++iter)
...@@ -167,23 +174,24 @@ namespace ...@@ -167,23 +174,24 @@ namespace
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <bool use_BIO_model, bool use_high_order_features> template <bool use_BIO_model, bool use_high_order_features, bool allow_negative_weights>
void do_test() void do_test()
{ {
dlog << LINFO << "use_BIO_model: "<< use_BIO_model; dlog << LINFO << "use_BIO_model: "<< use_BIO_model;
dlog << LINFO << "use_high_order_features: "<< use_high_order_features; dlog << LINFO << "use_high_order_features: "<< use_high_order_features;
dlog << LINFO << "allow_negative_weights: "<< allow_negative_weights;
std::vector<std::vector<unsigned long> > samples; std::vector<std::vector<unsigned long> > samples;
std::vector<std::vector<std::pair<unsigned long,unsigned long> > > segments; std::vector<std::vector<std::pair<unsigned long,unsigned long> > > segments;
make_dataset2( samples, segments, 200); make_dataset2( samples, segments, 100);
print_spinner(); print_spinner();
typedef unigram_extractor<use_BIO_model,use_high_order_features> fe_type; typedef unigram_extractor<use_BIO_model,use_high_order_features,allow_negative_weights> fe_type;
fe_type fe_temp; fe_type fe_temp;
fe_type fe_temp2; fe_type fe_temp2;
structural_sequence_segmentation_trainer<fe_type> trainer(fe_temp2); structural_sequence_segmentation_trainer<fe_type> trainer(fe_temp2);
trainer.set_c(4); trainer.set_c(5);
trainer.set_num_threads(1); trainer.set_num_threads(1);
...@@ -214,9 +222,9 @@ namespace ...@@ -214,9 +222,9 @@ namespace
matrix<double> res; matrix<double> res;
res = cross_validate_sequence_segmenter(trainer, samples, segments, 3); res = cross_validate_sequence_segmenter(trainer, samples, segments, 3);
DLIB_TEST(min(res) > 0.98);
dlog << LINFO << "cv res: "<< res; dlog << LINFO << "cv res: "<< res;
make_dataset2( samples, segments, 300); DLIB_TEST(min(res) > 0.98);
make_dataset2( samples, segments, 100);
res = test_sequence_segmenter(labeler, samples, segments); res = test_sequence_segmenter(labeler, samples, segments);
dlog << LINFO << "test res: "<< res; dlog << LINFO << "test res: "<< res;
DLIB_TEST(min(res) > 0.98); DLIB_TEST(min(res) > 0.98);
...@@ -232,6 +240,26 @@ namespace ...@@ -232,6 +240,26 @@ namespace
res = test_sequence_segmenter(labeler2, samples, segments); res = test_sequence_segmenter(labeler2, samples, segments);
dlog << LINFO << "test res2: "<< res; dlog << LINFO << "test res2: "<< res;
DLIB_TEST(min(res) > 0.98); DLIB_TEST(min(res) > 0.98);
long N;
if (use_BIO_model)
N = 3*3;
else
N = 5*5;
const double min_normal_weight = min(colm(labeler2.get_weights(), 0, labeler2.get_weights().size()-N));
const double min_trans_weight = min(labeler2.get_weights());
dlog << LINFO << "min_normal_weight: " << min_normal_weight;
dlog << LINFO << "min_trans_weight: " << min_trans_weight;
if (allow_negative_weights)
{
DLIB_TEST(min_normal_weight < 0);
DLIB_TEST(min_trans_weight < 0);
}
else
{
DLIB_TEST(min_normal_weight == 0);
DLIB_TEST(min_trans_weight < 0);
}
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -249,10 +277,14 @@ namespace ...@@ -249,10 +277,14 @@ namespace
void perform_test ( void perform_test (
) )
{ {
do_test<true,true>(); do_test<true,true,false>();
do_test<true,false>(); do_test<true,false,false>();
do_test<false,true>(); do_test<false,true,false>();
do_test<false,false>(); do_test<false,false,false>();
do_test<true,true,true>();
do_test<true,false,true>();
do_test<false,true,true>();
do_test<false,false,true>();
} }
} a; } a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment