Commit d1cf19fc authored by Davis King

Added an option to learn just non-negative weights.

parent 1efcfb3d
......@@ -44,6 +44,22 @@ namespace dlib
feature_extractor() {}
feature_extractor(const ss_feature_extractor& ss_fe_) : fe(ss_fe_) {}
// Reports how many elements of the weight vector are required to be
// non-negative, as selected by ss_feature_extractor::allow_negative_weights.
unsigned long num_nonnegative_weights (
) const
{
    // Negative weights are permitted everywhere, so nothing is constrained.
    if (ss_feature_extractor::allow_negative_weights)
        return 0;

    // Every feature is constrained to be non-negative except the label
    // transition indicators, of which there is one per ordered pair of
    // labels (3 labels when use_BIO_model is set, 5 otherwise).
    const unsigned long label_count = ss_feature_extractor::use_BIO_model ? 3 : 5;
    const unsigned long transition_count = label_count*label_count;
    return num_features() - transition_count;
}
friend void serialize(const feature_extractor& item, std::ostream& out)
{
serialize(item.fe, out);
......@@ -181,12 +197,7 @@ namespace dlib
unsigned long position
) const
{
// Pull out an indicator feature for the type of transition between the
// previous label and the current label.
if (y.size() > 1)
set_feature(y(1)*num_labels() + y(0));
unsigned long offset = num_labels()*num_labels();
unsigned long offset = 0;
const int window_size = fe.window_size();
......@@ -214,6 +225,10 @@ namespace dlib
offset += num_labels()*base_dims;
}
// Pull out an indicator feature for the type of transition between the
// previous label and the current label.
if (y.size() > 1)
set_feature(offset + y(1)*num_labels() + y(0));
}
};
......
......@@ -162,8 +162,10 @@ namespace dlib
have been similarly defined except that there would be 5*5+5 slots for
the various label combination instead of 3*3+3.
Finally, while not shown here, we also include nine indicator features
in XI() to model label transitions.
Finally, while not shown here, we also include indicator features in
XI() to model label transitions. These are 9 extra features in the
case of the BIO tagging model and 25 extra in the case of the BILOU
tagging model.
THREAD SAFETY
Instances of this object are required to be threadsafe, that is, it should
......@@ -188,6 +190,13 @@ namespace dlib
// the previous label.
const static bool use_high_order_features = true;
// You use a tool like the structural_sequence_segmentation_trainer to learn the
// weight vector needed by a sequence_segmenter. You can tell the trainer to force
// all the elements of the weight vector corresponding to ZI() to be non-negative.
// That is, every element of w except those corresponding to the label transition
// indicator features. To do this, just set allow_negative_weights to false.
const static bool allow_negative_weights = true;
example_feature_extractor (
);
......
......@@ -20,13 +20,14 @@ namespace
dlib::rand rnd;
template <bool use_BIO_model_, bool use_high_order_features_>
template <bool use_BIO_model_, bool use_high_order_features_, bool allow_negative_weights_>
class unigram_extractor
{
public:
const static bool use_BIO_model = use_BIO_model_;
const static bool use_high_order_features = use_high_order_features_;
const static bool allow_negative_weights = allow_negative_weights_;
typedef std::vector<unsigned long> sequence_type;
......@@ -38,6 +39,12 @@ namespace
v1 = randm(num_features(), 1, rnd);
v2 = randm(num_features(), 1, rnd);
v3 = randm(num_features(), 1, rnd);
v1(0) = 1;
v2(1) = 1;
v3(2) = 1;
v1(3) = -1;
v2(4) = -1;
v3(5) = -1;
for (unsigned long i = 0; i < num_features(); ++i)
{
if ( i < 3)
......@@ -68,14 +75,14 @@ namespace
};
template <bool use_BIO_model_, bool use_high_order_features_>
void serialize(const unigram_extractor<use_BIO_model_,use_high_order_features_>& item , std::ostream& out )
// Serializes a unigram_extractor by writing its feats member to out.
// The three bool template parameters are forwarded unchanged as the
// template arguments of unigram_extractor, so this overload matches every
// instantiation of that class.
template <bool use_BIO_model_, bool use_high_order_features_, bool neg>
void serialize(const unigram_extractor<use_BIO_model_,use_high_order_features_,neg>& item , std::ostream& out )
{
// feats is the only member written here.
serialize(item.feats, out);
}
template <bool use_BIO_model_, bool use_high_order_features_>
void deserialize(unigram_extractor<use_BIO_model_,use_high_order_features_>& item, std::istream& in)
// Deserializes a unigram_extractor by reading its feats member from in.
// The three bool template parameters are forwarded unchanged as the
// template arguments of unigram_extractor, so this overload matches every
// instantiation of that class.
template <bool use_BIO_model_, bool use_high_order_features_, bool neg>
void deserialize(unigram_extractor<use_BIO_model_,use_high_order_features_,neg>& item, std::istream& in)
{
// feats is the only member read here.
deserialize(item.feats, in);
}
......@@ -95,7 +102,7 @@ namespace
labels.resize(dataset_size);
unigram_extractor<true,true> fe;
unigram_extractor<true,true,true> fe;
dlib::rand rnd;
for (unsigned long iter = 0; iter < dataset_size; ++iter)
......@@ -167,23 +174,24 @@ namespace
// ----------------------------------------------------------------------------------------
template <bool use_BIO_model, bool use_high_order_features>
template <bool use_BIO_model, bool use_high_order_features, bool allow_negative_weights>
void do_test()
{
dlog << LINFO << "use_BIO_model: "<< use_BIO_model;
dlog << LINFO << "use_high_order_features: "<< use_high_order_features;
dlog << LINFO << "allow_negative_weights: "<< allow_negative_weights;
std::vector<std::vector<unsigned long> > samples;
std::vector<std::vector<std::pair<unsigned long,unsigned long> > > segments;
make_dataset2( samples, segments, 200);
make_dataset2( samples, segments, 100);
print_spinner();
typedef unigram_extractor<use_BIO_model,use_high_order_features> fe_type;
typedef unigram_extractor<use_BIO_model,use_high_order_features,allow_negative_weights> fe_type;
fe_type fe_temp;
fe_type fe_temp2;
structural_sequence_segmentation_trainer<fe_type> trainer(fe_temp2);
trainer.set_c(4);
trainer.set_c(5);
trainer.set_num_threads(1);
......@@ -214,9 +222,9 @@ namespace
matrix<double> res;
res = cross_validate_sequence_segmenter(trainer, samples, segments, 3);
DLIB_TEST(min(res) > 0.98);
dlog << LINFO << "cv res: "<< res;
make_dataset2( samples, segments, 300);
DLIB_TEST(min(res) > 0.98);
make_dataset2( samples, segments, 100);
res = test_sequence_segmenter(labeler, samples, segments);
dlog << LINFO << "test res: "<< res;
DLIB_TEST(min(res) > 0.98);
......@@ -232,6 +240,26 @@ namespace
res = test_sequence_segmenter(labeler2, samples, segments);
dlog << LINFO << "test res2: "<< res;
DLIB_TEST(min(res) > 0.98);
long N;
if (use_BIO_model)
N = 3*3;
else
N = 5*5;
const double min_normal_weight = min(colm(labeler2.get_weights(), 0, labeler2.get_weights().size()-N));
const double min_trans_weight = min(labeler2.get_weights());
dlog << LINFO << "min_normal_weight: " << min_normal_weight;
dlog << LINFO << "min_trans_weight: " << min_trans_weight;
if (allow_negative_weights)
{
DLIB_TEST(min_normal_weight < 0);
DLIB_TEST(min_trans_weight < 0);
}
else
{
DLIB_TEST(min_normal_weight == 0);
DLIB_TEST(min_trans_weight < 0);
}
}
// ----------------------------------------------------------------------------------------
......@@ -249,10 +277,14 @@ namespace
// Runs do_test() once for each listed combination of its boolean template
// arguments, which are, in order: use_BIO_model, use_high_order_features
// and allow_negative_weights.
void perform_test (
)
{
// NOTE(review): the four two-argument calls below look like the pre-change
// lines of this diff hunk, superseded by the three-argument calls that
// follow -- confirm against the actual file.
do_test<true,true>();
do_test<true,false>();
do_test<false,true>();
do_test<false,false>();
// All four (use_BIO_model, use_high_order_features) pairs with
// allow_negative_weights == false.
do_test<true,true,false>();
do_test<true,false,false>();
do_test<false,true,false>();
do_test<false,false,false>();
// The same four pairs with allow_negative_weights == true.
do_test<true,true,true>();
do_test<true,false,true>();
do_test<false,true,true>();
do_test<false,false,true>();
}
} a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment