Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
d1cf19fc
Commit
d1cf19fc
authored
May 15, 2013
by
Davis King
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added an option to learn just non-negative weights.
parent
1efcfb3d
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
80 additions
and
24 deletions
+80
-24
sequence_segmenter.h
dlib/svm/sequence_segmenter.h
+21
-6
sequence_segmenter_abstract.h
dlib/svm/sequence_segmenter_abstract.h
+11
-2
sequence_segmenter.cpp
dlib/test/sequence_segmenter.cpp
+48
-16
No files found.
dlib/svm/sequence_segmenter.h
View file @
d1cf19fc
...
@@ -44,6 +44,22 @@ namespace dlib
...
@@ -44,6 +44,22 @@ namespace dlib
feature_extractor
()
{}
feature_extractor
()
{}
feature_extractor
(
const
ss_feature_extractor
&
ss_fe_
)
:
fe
(
ss_fe_
)
{}
feature_extractor
(
const
ss_feature_extractor
&
ss_fe_
)
:
fe
(
ss_fe_
)
{}
unsigned
long
num_nonnegative_weights
(
)
const
{
const
unsigned
long
NL
=
ss_feature_extractor
::
use_BIO_model
?
3
:
5
;
if
(
ss_feature_extractor
::
allow_negative_weights
)
{
return
0
;
}
else
{
// We make everything non-negative except for the label transition
// features.
return
num_features
()
-
NL
*
NL
;
}
}
friend
void
serialize
(
const
feature_extractor
&
item
,
std
::
ostream
&
out
)
friend
void
serialize
(
const
feature_extractor
&
item
,
std
::
ostream
&
out
)
{
{
serialize
(
item
.
fe
,
out
);
serialize
(
item
.
fe
,
out
);
...
@@ -181,12 +197,7 @@ namespace dlib
...
@@ -181,12 +197,7 @@ namespace dlib
unsigned
long
position
unsigned
long
position
)
const
)
const
{
{
// Pull out an indicator feature for the type of transition between the
unsigned
long
offset
=
0
;
// previous label and the current label.
if
(
y
.
size
()
>
1
)
set_feature
(
y
(
1
)
*
num_labels
()
+
y
(
0
));
unsigned
long
offset
=
num_labels
()
*
num_labels
();
const
int
window_size
=
fe
.
window_size
();
const
int
window_size
=
fe
.
window_size
();
...
@@ -214,6 +225,10 @@ namespace dlib
...
@@ -214,6 +225,10 @@ namespace dlib
offset
+=
num_labels
()
*
base_dims
;
offset
+=
num_labels
()
*
base_dims
;
}
}
// Pull out an indicator feature for the type of transition between the
// previous label and the current label.
if
(
y
.
size
()
>
1
)
set_feature
(
offset
+
y
(
1
)
*
num_labels
()
+
y
(
0
));
}
}
};
};
...
...
dlib/svm/sequence_segmenter_abstract.h
View file @
d1cf19fc
...
@@ -162,8 +162,10 @@ namespace dlib
...
@@ -162,8 +162,10 @@ namespace dlib
have been similarly defined except that there would be 5*5+5 slots for
have been similarly defined except that there would be 5*5+5 slots for
the various label combination instead of 3*3+3.
the various label combination instead of 3*3+3.
Finally, while not shown here, we also include nine indicator features
Finally, while not shown here, we also include indicator features in
in XI() to model label transitions.
XI() to model label transitions. These are 9 extra features in the
case of the BIO tagging model and 25 extra in the case of the BILOU
tagging model.
THREAD SAFETY
THREAD SAFETY
Instances of this object are required to be threadsafe, that is, it should
Instances of this object are required to be threadsafe, that is, it should
...
@@ -188,6 +190,13 @@ namespace dlib
...
@@ -188,6 +190,13 @@ namespace dlib
// the previous label.
// the previous label.
const
static
bool
use_high_order_features
=
true
;
const
static
bool
use_high_order_features
=
true
;
// You use a tool like the structural_sequence_segmentation_trainer to learn the
// weight vector needed by a sequence_segmenter. You can tell the trainer to force
// all the elements of the weight vector corresponding to ZI() to be non-negative.
// This is all the elements of w except for the elements corresponding to the label
// transition indicator features. To do this, just set allow_negative_weights to false.
const
static
bool
allow_negative_weights
=
true
;
example_feature_extractor
(
example_feature_extractor
(
);
);
...
...
dlib/test/sequence_segmenter.cpp
View file @
d1cf19fc
...
@@ -20,13 +20,14 @@ namespace
...
@@ -20,13 +20,14 @@ namespace
dlib
::
rand
rnd
;
dlib
::
rand
rnd
;
template
<
bool
use_BIO_model_
,
bool
use_high_order_features_
>
template
<
bool
use_BIO_model_
,
bool
use_high_order_features_
,
bool
allow_negative_weights_
>
class
unigram_extractor
class
unigram_extractor
{
{
public
:
public
:
const
static
bool
use_BIO_model
=
use_BIO_model_
;
const
static
bool
use_BIO_model
=
use_BIO_model_
;
const
static
bool
use_high_order_features
=
use_high_order_features_
;
const
static
bool
use_high_order_features
=
use_high_order_features_
;
const
static
bool
allow_negative_weights
=
allow_negative_weights_
;
typedef
std
::
vector
<
unsigned
long
>
sequence_type
;
typedef
std
::
vector
<
unsigned
long
>
sequence_type
;
...
@@ -38,6 +39,12 @@ namespace
...
@@ -38,6 +39,12 @@ namespace
v1
=
randm
(
num_features
(),
1
,
rnd
);
v1
=
randm
(
num_features
(),
1
,
rnd
);
v2
=
randm
(
num_features
(),
1
,
rnd
);
v2
=
randm
(
num_features
(),
1
,
rnd
);
v3
=
randm
(
num_features
(),
1
,
rnd
);
v3
=
randm
(
num_features
(),
1
,
rnd
);
v1
(
0
)
=
1
;
v2
(
1
)
=
1
;
v3
(
2
)
=
1
;
v1
(
3
)
=
-
1
;
v2
(
4
)
=
-
1
;
v3
(
5
)
=
-
1
;
for
(
unsigned
long
i
=
0
;
i
<
num_features
();
++
i
)
for
(
unsigned
long
i
=
0
;
i
<
num_features
();
++
i
)
{
{
if
(
i
<
3
)
if
(
i
<
3
)
...
@@ -68,14 +75,14 @@ namespace
...
@@ -68,14 +75,14 @@ namespace
};
};
template
<
bool
use_BIO_model_
,
bool
use_high_order_features_
>
template
<
bool
use_BIO_model_
,
bool
use_high_order_features_
,
bool
neg
>
void
serialize
(
const
unigram_extractor
<
use_BIO_model_
,
use_high_order_features_
>&
item
,
std
::
ostream
&
out
)
void
serialize
(
const
unigram_extractor
<
use_BIO_model_
,
use_high_order_features_
,
neg
>&
item
,
std
::
ostream
&
out
)
{
{
serialize
(
item
.
feats
,
out
);
serialize
(
item
.
feats
,
out
);
}
}
template
<
bool
use_BIO_model_
,
bool
use_high_order_features_
>
template
<
bool
use_BIO_model_
,
bool
use_high_order_features_
,
bool
neg
>
void
deserialize
(
unigram_extractor
<
use_BIO_model_
,
use_high_order_features_
>&
item
,
std
::
istream
&
in
)
void
deserialize
(
unigram_extractor
<
use_BIO_model_
,
use_high_order_features_
,
neg
>&
item
,
std
::
istream
&
in
)
{
{
deserialize
(
item
.
feats
,
in
);
deserialize
(
item
.
feats
,
in
);
}
}
...
@@ -95,7 +102,7 @@ namespace
...
@@ -95,7 +102,7 @@ namespace
labels
.
resize
(
dataset_size
);
labels
.
resize
(
dataset_size
);
unigram_extractor
<
true
,
true
>
fe
;
unigram_extractor
<
true
,
true
,
true
>
fe
;
dlib
::
rand
rnd
;
dlib
::
rand
rnd
;
for
(
unsigned
long
iter
=
0
;
iter
<
dataset_size
;
++
iter
)
for
(
unsigned
long
iter
=
0
;
iter
<
dataset_size
;
++
iter
)
...
@@ -167,23 +174,24 @@ namespace
...
@@ -167,23 +174,24 @@ namespace
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template
<
bool
use_BIO_model
,
bool
use_high_order_features
>
template
<
bool
use_BIO_model
,
bool
use_high_order_features
,
bool
allow_negative_weights
>
void
do_test
()
void
do_test
()
{
{
dlog
<<
LINFO
<<
"use_BIO_model: "
<<
use_BIO_model
;
dlog
<<
LINFO
<<
"use_BIO_model: "
<<
use_BIO_model
;
dlog
<<
LINFO
<<
"use_high_order_features: "
<<
use_high_order_features
;
dlog
<<
LINFO
<<
"use_high_order_features: "
<<
use_high_order_features
;
dlog
<<
LINFO
<<
"allow_negative_weights: "
<<
allow_negative_weights
;
std
::
vector
<
std
::
vector
<
unsigned
long
>
>
samples
;
std
::
vector
<
std
::
vector
<
unsigned
long
>
>
samples
;
std
::
vector
<
std
::
vector
<
std
::
pair
<
unsigned
long
,
unsigned
long
>
>
>
segments
;
std
::
vector
<
std
::
vector
<
std
::
pair
<
unsigned
long
,
unsigned
long
>
>
>
segments
;
make_dataset2
(
samples
,
segments
,
2
00
);
make_dataset2
(
samples
,
segments
,
1
00
);
print_spinner
();
print_spinner
();
typedef
unigram_extractor
<
use_BIO_model
,
use_high_order_features
>
fe_type
;
typedef
unigram_extractor
<
use_BIO_model
,
use_high_order_features
,
allow_negative_weights
>
fe_type
;
fe_type
fe_temp
;
fe_type
fe_temp
;
fe_type
fe_temp2
;
fe_type
fe_temp2
;
structural_sequence_segmentation_trainer
<
fe_type
>
trainer
(
fe_temp2
);
structural_sequence_segmentation_trainer
<
fe_type
>
trainer
(
fe_temp2
);
trainer
.
set_c
(
4
);
trainer
.
set_c
(
5
);
trainer
.
set_num_threads
(
1
);
trainer
.
set_num_threads
(
1
);
...
@@ -214,9 +222,9 @@ namespace
...
@@ -214,9 +222,9 @@ namespace
matrix
<
double
>
res
;
matrix
<
double
>
res
;
res
=
cross_validate_sequence_segmenter
(
trainer
,
samples
,
segments
,
3
);
res
=
cross_validate_sequence_segmenter
(
trainer
,
samples
,
segments
,
3
);
DLIB_TEST
(
min
(
res
)
>
0.98
);
dlog
<<
LINFO
<<
"cv res: "
<<
res
;
dlog
<<
LINFO
<<
"cv res: "
<<
res
;
make_dataset2
(
samples
,
segments
,
300
);
DLIB_TEST
(
min
(
res
)
>
0.98
);
make_dataset2
(
samples
,
segments
,
100
);
res
=
test_sequence_segmenter
(
labeler
,
samples
,
segments
);
res
=
test_sequence_segmenter
(
labeler
,
samples
,
segments
);
dlog
<<
LINFO
<<
"test res: "
<<
res
;
dlog
<<
LINFO
<<
"test res: "
<<
res
;
DLIB_TEST
(
min
(
res
)
>
0.98
);
DLIB_TEST
(
min
(
res
)
>
0.98
);
...
@@ -232,6 +240,26 @@ namespace
...
@@ -232,6 +240,26 @@ namespace
res
=
test_sequence_segmenter
(
labeler2
,
samples
,
segments
);
res
=
test_sequence_segmenter
(
labeler2
,
samples
,
segments
);
dlog
<<
LINFO
<<
"test res2: "
<<
res
;
dlog
<<
LINFO
<<
"test res2: "
<<
res
;
DLIB_TEST
(
min
(
res
)
>
0.98
);
DLIB_TEST
(
min
(
res
)
>
0.98
);
long
N
;
if
(
use_BIO_model
)
N
=
3
*
3
;
else
N
=
5
*
5
;
const
double
min_normal_weight
=
min
(
colm
(
labeler2
.
get_weights
(),
0
,
labeler2
.
get_weights
().
size
()
-
N
));
const
double
min_trans_weight
=
min
(
labeler2
.
get_weights
());
dlog
<<
LINFO
<<
"min_normal_weight: "
<<
min_normal_weight
;
dlog
<<
LINFO
<<
"min_trans_weight: "
<<
min_trans_weight
;
if
(
allow_negative_weights
)
{
DLIB_TEST
(
min_normal_weight
<
0
);
DLIB_TEST
(
min_trans_weight
<
0
);
}
else
{
DLIB_TEST
(
min_normal_weight
==
0
);
DLIB_TEST
(
min_trans_weight
<
0
);
}
}
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
...
@@ -249,10 +277,14 @@ namespace
...
@@ -249,10 +277,14 @@ namespace
void
perform_test
(
void
perform_test
(
)
)
{
{
do_test
<
true
,
true
>
();
do_test
<
true
,
true
,
false
>
();
do_test
<
true
,
false
>
();
do_test
<
true
,
false
,
false
>
();
do_test
<
false
,
true
>
();
do_test
<
false
,
true
,
false
>
();
do_test
<
false
,
false
>
();
do_test
<
false
,
false
,
false
>
();
do_test
<
true
,
true
,
true
>
();
do_test
<
true
,
false
,
true
>
();
do_test
<
false
,
true
,
true
>
();
do_test
<
false
,
false
,
true
>
();
}
}
}
a
;
}
a
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment