Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
4f08da0d
Commit
4f08da0d
authored
Nov 04, 2011
by
Davis King
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added unit tests for the new sequence labeling stuff
parent
d933439a
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
292 additions
and
0 deletions
+292
-0
CMakeLists.txt
dlib/test/CMakeLists.txt
+1
-0
makefile
dlib/test/makefile
+1
-0
sequence_labeler.cpp
dlib/test/sequence_labeler.cpp
+290
-0
No files found.
dlib/test/CMakeLists.txt
View file @
4f08da0d
...
...
@@ -87,6 +87,7 @@ set (tests
reference_counter.cpp
scan_image.cpp
sequence.cpp
sequence_labeler.cpp
serialize.cpp
set.cpp
sldf.cpp
...
...
dlib/test/makefile
View file @
4f08da0d
...
...
@@ -102,6 +102,7 @@ SRC += read_write_mutex.cpp
SRC
+=
reference_counter.cpp
SRC
+=
scan_image.cpp
SRC
+=
sequence.cpp
SRC
+=
sequence_labeler.cpp
SRC
+=
serialize.cpp
SRC
+=
set.cpp
SRC
+=
sldf.cpp
...
...
dlib/test/sequence_labeler.cpp
0 → 100644
View file @
4f08da0d
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include "tester.h"
#include <dlib/svm_threaded.h>
#include <dlib/rand.h>
namespace
{
using
namespace
test
;
using
namespace
dlib
;
using
namespace
std
;
logger
dlog
(
"test.sequence_labeler"
);
// ----------------------------------------------------------------------------------------
const
unsigned
long
num_label_states
=
3
;
// the "hidden" states
const
unsigned
long
num_sample_states
=
3
;
// ----------------------------------------------------------------------------------------
class
feature_extractor
{
public
:
typedef
unsigned
long
sample_type
;
unsigned
long
num_features
()
const
{
return
num_label_states
*
num_label_states
+
num_label_states
*
num_sample_states
;
}
unsigned
long
order
()
const
{
return
1
;
}
unsigned
long
num_labels
()
const
{
return
num_label_states
;
}
template
<
typename
feature_setter
,
typename
EXP
>
void
get_features
(
feature_setter
&
set_feature
,
const
std
::
vector
<
sample_type
>&
x
,
const
matrix_exp
<
EXP
>&
y
,
unsigned
long
position
)
const
{
if
(
y
.
size
()
>
1
)
set_feature
(
y
(
1
)
*
num_label_states
+
y
(
0
));
set_feature
(
num_label_states
*
num_label_states
+
y
(
0
)
*
num_sample_states
+
x
[
position
]);
}
};
void
serialize
(
const
feature_extractor
&
,
std
::
ostream
&
)
{}
void
deserialize
(
feature_extractor
&
,
std
::
istream
&
)
{}
// ----------------------------------------------------------------------------------------
void
sample_hmm
(
dlib
::
rand
&
rnd
,
const
matrix
<
double
>&
transition_probabilities
,
const
matrix
<
double
>&
emission_probabilities
,
unsigned
long
previous_label
,
unsigned
long
&
next_label
,
unsigned
long
&
next_sample
)
/*!
requires
- previous_label < transition_probabilities.nr()
- transition_probabilities.nr() == transition_probabilities.nc()
- transition_probabilities.nr() == emission_probabilities.nr()
- The rows of transition_probabilities and emission_probabilities must sum to 1.
(i.e. sum_cols(transition_probabilities) and sum_cols(emission_probabilities)
must evaluate to vectors of all 1s.)
ensures
- This function randomly samples the HMM defined by transition_probabilities
and emission_probabilities assuming that the previous hidden state
was previous_label.
- The HMM is defined by:
- P(next_label |previous_label) == transition_probabilities(previous_label, next_label)
- P(next_sample|next_label) == emission_probabilities (next_label, next_sample)
- #next_label == the sampled value of the hidden state
- #next_sample == the sampled value of the observed state
!*/
{
// sample next_label
double
p
=
rnd
.
get_random_double
();
for
(
long
c
=
0
;
p
>=
0
&&
c
<
transition_probabilities
.
nc
();
++
c
)
{
next_label
=
c
;
p
-=
transition_probabilities
(
previous_label
,
c
);
}
// now sample next_sample
p
=
rnd
.
get_random_double
();
for
(
long
c
=
0
;
p
>=
0
&&
c
<
emission_probabilities
.
nc
();
++
c
)
{
next_sample
=
c
;
p
-=
emission_probabilities
(
next_label
,
c
);
}
}
// ----------------------------------------------------------------------------------------
void
make_dataset
(
const
matrix
<
double
>&
transition_probabilities
,
const
matrix
<
double
>&
emission_probabilities
,
std
::
vector
<
std
::
vector
<
unsigned
long
>
>&
samples
,
std
::
vector
<
std
::
vector
<
unsigned
long
>
>&
labels
,
unsigned
long
dataset_size
)
/*!
requires
- transition_probabilities.nr() == transition_probabilities.nc()
- transition_probabilities.nr() == emission_probabilities.nr()
- The rows of transition_probabilities and emission_probabilities must sum to 1.
(i.e. sum_cols(transition_probabilities) and sum_cols(emission_probabilities)
must evaluate to vectors of all 1s.)
ensures
- This function randomly samples a bunch of sequences from the HMM defined by
transition_probabilities and emission_probabilities.
- The HMM is defined by:
- The probability of transitioning from hidden state H1 to H2
is given by transition_probabilities(H1,H2).
- The probability of a hidden state H producing an observed state
O is given by emission_probabilities(H,O).
- #samples.size() == labels.size() == dataset_size
- for all valid i:
- #labels[i] is a randomly sampled sequence of hidden states from the
given HMM. #samples[i] is its corresponding randomly sampled sequence
of observed states.
!*/
{
samples
.
clear
();
labels
.
clear
();
dlib
::
rand
rnd
;
// now randomly sample some labeled sequences from our Hidden Markov Model
for
(
unsigned
long
iter
=
0
;
iter
<
dataset_size
;
++
iter
)
{
const
unsigned
long
sequence_size
=
rnd
.
get_random_32bit_number
()
%
20
+
3
;
std
::
vector
<
unsigned
long
>
sample
(
sequence_size
);
std
::
vector
<
unsigned
long
>
label
(
sequence_size
);
unsigned
long
previous_label
=
rnd
.
get_random_32bit_number
()
%
num_label_states
;
for
(
unsigned
long
i
=
0
;
i
<
sample
.
size
();
++
i
)
{
unsigned
long
next_label
=
0
,
next_sample
=
0
;
sample_hmm
(
rnd
,
transition_probabilities
,
emission_probabilities
,
previous_label
,
next_label
,
next_sample
);
label
[
i
]
=
next_label
;
sample
[
i
]
=
next_sample
;
previous_label
=
next_label
;
}
samples
.
push_back
(
sample
);
labels
.
push_back
(
label
);
}
}
// ----------------------------------------------------------------------------------------
class
sequence_labeler_tester
:
public
tester
{
public
:
sequence_labeler_tester
(
)
:
tester
(
"test_sequence_labeler"
,
"Runs tests on the sequence labeling code."
)
{}
void
perform_test
(
)
{
matrix
<
double
>
transition_probabilities
(
num_label_states
,
num_label_states
);
transition_probabilities
=
0.05
,
0.90
,
0.05
,
0.05
,
0.05
,
0.90
,
0.90
,
0.05
,
0.05
;
matrix
<
double
>
emission_probabilities
(
num_label_states
,
num_sample_states
);
emission_probabilities
=
0.5
,
0.5
,
0.0
,
0.0
,
0.5
,
0.5
,
0.5
,
0.0
,
0.5
;
print_spinner
();
std
::
vector
<
std
::
vector
<
unsigned
long
>
>
samples
;
std
::
vector
<
std
::
vector
<
unsigned
long
>
>
labels
;
make_dataset
(
transition_probabilities
,
emission_probabilities
,
samples
,
labels
,
1000
);
dlog
<<
LINFO
<<
"samples.size(): "
<<
samples
.
size
();
// print out some of the randomly sampled sequences
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
dlog
<<
LINFO
<<
"hidden states: "
<<
trans
(
vector_to_matrix
(
labels
[
i
]));
dlog
<<
LINFO
<<
"observed states: "
<<
trans
(
vector_to_matrix
(
samples
[
i
]));
dlog
<<
LINFO
<<
"******************************"
;
}
print_spinner
();
structural_sequence_labeling_trainer
<
feature_extractor
>
trainer
;
trainer
.
set_c
(
4
);
DLIB_TEST
(
trainer
.
get_c
()
==
4
);
trainer
.
set_num_threads
(
4
);
DLIB_TEST
(
trainer
.
get_num_threads
()
==
4
);
// Learn to do sequence labeling from the dataset
sequence_labeler
<
feature_extractor
>
labeler
=
trainer
.
train
(
samples
,
labels
);
std
::
vector
<
unsigned
long
>
predicted_labels
=
labeler
(
samples
[
0
]);
dlog
<<
LINFO
<<
"true hidden states: "
<<
trans
(
vector_to_matrix
(
labels
[
0
]));
dlog
<<
LINFO
<<
"predicted hidden states: "
<<
trans
(
vector_to_matrix
(
predicted_labels
));
DLIB_TEST
(
vector_to_matrix
(
labels
[
0
])
==
vector_to_matrix
(
predicted_labels
));
print_spinner
();
// We can also do cross-validation
matrix
<
double
>
confusion_matrix
;
confusion_matrix
=
cross_validate_sequence_labeler
(
trainer
,
samples
,
labels
,
4
);
dlog
<<
LINFO
<<
"cross-validation: "
;
dlog
<<
LINFO
<<
confusion_matrix
;
double
accuracy
=
sum
(
diag
(
confusion_matrix
))
/
sum
(
confusion_matrix
);
dlog
<<
LINFO
<<
"label accuracy: "
<<
accuracy
;
DLIB_TEST
(
std
::
abs
(
accuracy
-
0.882
)
<
0.01
);
print_spinner
();
matrix
<
double
,
0
,
1
>
true_hmm_model_weights
=
log
(
join_cols
(
reshape_to_column_vector
(
transition_probabilities
),
reshape_to_column_vector
(
emission_probabilities
)));
sequence_labeler
<
feature_extractor
>
labeler_true
(
true_hmm_model_weights
);
confusion_matrix
=
test_sequence_labeler
(
labeler_true
,
samples
,
labels
);
dlog
<<
LINFO
<<
"True HMM model: "
;
dlog
<<
LINFO
<<
confusion_matrix
;
accuracy
=
sum
(
diag
(
confusion_matrix
))
/
sum
(
confusion_matrix
);
dlog
<<
LINFO
<<
"label accuracy: "
<<
accuracy
;
DLIB_TEST
(
std
::
abs
(
accuracy
-
0.882
)
<
0.01
);
print_spinner
();
// Finally, the labeler can be serialized to disk just like most dlib objects.
ostringstream
sout
;
serialize
(
labeler
,
sout
);
sequence_labeler
<
feature_extractor
>
labeler2
;
// recall from disk
istringstream
sin
(
sout
.
str
());
deserialize
(
labeler2
,
sin
);
confusion_matrix
=
test_sequence_labeler
(
labeler2
,
samples
,
labels
);
dlog
<<
LINFO
<<
"deserialized labeler: "
;
dlog
<<
LINFO
<<
confusion_matrix
;
accuracy
=
sum
(
diag
(
confusion_matrix
))
/
sum
(
confusion_matrix
);
dlog
<<
LINFO
<<
"label accuracy: "
<<
accuracy
;
DLIB_TEST
(
std
::
abs
(
accuracy
-
0.882
)
<
0.01
);
}
}
a
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment