Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
fdc3af3a
Commit
fdc3af3a
authored
Nov 01, 2011
by
Davis King
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Refactoring and spec improvement. Still some work left to do though.
parent
7dc67b80
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
140 additions
and
37 deletions
+140
-37
cross_validate_sequence_labeler.h
dlib/svm/cross_validate_sequence_labeler.h
+24
-34
cross_validate_sequence_labeler_abstract.h
dlib/svm/cross_validate_sequence_labeler_abstract.h
+21
-1
sequence_labeler.h
dlib/svm/sequence_labeler.h
+28
-2
structural_sequence_labeling_trainer.h
dlib/svm/structural_sequence_labeling_trainer.h
+37
-0
structural_sequence_labeling_trainer_abstract.h
dlib/svm/structural_sequence_labeling_trainer_abstract.h
+30
-0
No files found.
dlib/svm/cross_validate_sequence_labeler.h
View file @
fdc3af3a
...
...
@@ -25,7 +25,7 @@ namespace dlib
)
{
// make sure requires clause is not broken
DLIB_
C
ASSERT
(
is_sequence_labeling_problem
(
samples
,
labels
)
==
true
,
DLIB_ASSERT
(
is_sequence_labeling_problem
(
samples
,
labels
)
==
true
,
"
\t
matrix test_sequence_labeler()"
<<
"
\n\t
invalid inputs were given to this function"
<<
"
\n\t
is_sequence_labeling_problem(samples, labels): "
...
...
@@ -44,8 +44,8 @@ namespace dlib
const
unsigned
long
truth
=
labels
[
i
][
j
];
if
(
truth
>=
res
.
nr
())
{
//
make res big enough for this unexpected label
res
=
join_cols
(
res
,
zeros_matrix
<
double
>
(
truth
-
res
.
nr
()
+
1
,
res
.
nc
()))
;
//
ignore labels the labeler doesn't know about.
continue
;
}
res
(
truth
,
pred
[
j
])
+=
1
;
...
...
@@ -69,7 +69,7 @@ namespace dlib
)
{
// make sure requires clause is not broken
DLIB_
C
ASSERT
(
is_sequence_labeling_problem
(
samples
,
labels
)
==
true
&&
DLIB_ASSERT
(
is_sequence_labeling_problem
(
samples
,
labels
)
==
true
&&
1
<
folds
&&
folds
<=
static_cast
<
long
>
(
samples
.
size
()),
"
\t
matrix cross_validate_sequence_labeler()"
<<
"
\n\t
invalid inputs were given to this function"
...
...
@@ -78,6 +78,25 @@ namespace dlib
<<
"
\n\t
is_sequence_labeling_problem(samples,labels): "
<<
is_sequence_labeling_problem
(
samples
,
labels
)
);
#ifdef ENABLE_ASSERTS
for
(
unsigned
long
i
=
0
;
i
<
labels
.
size
();
++
i
)
{
for
(
unsigned
long
j
=
0
;
j
<
labels
[
i
].
size
();
++
j
)
{
// make sure requires clause is not broken
DLIB_ASSERT
(
labels
[
i
][
j
]
<
trainer
.
num_labels
(),
"
\t
matrix cross_validate_sequence_labeler()"
<<
"
\n\t
The labels are invalid."
<<
"
\n\t
labels[i][j]: "
<<
labels
[
i
][
j
]
<<
"
\n\t
trainer.num_labels(): "
<<
trainer
.
num_labels
()
<<
"
\n\t
i: "
<<
i
<<
"
\n\t
j: "
<<
j
);
}
}
#endif
const
long
num_in_test
=
samples
.
size
()
/
folds
;
...
...
@@ -117,36 +136,7 @@ namespace dlib
}
matrix
<
double
>
temp
=
test_sequence_labeler
(
trainer
.
train
(
x_train
,
y_train
),
x_test
,
y_test
);
// Make sure res is always at least as big as temp. This might not be the case
// because temp is sized differently depending on how many different kinds of labels
// test_sequence_labeler() sees.
if
(
get_rect
(
res
).
contains
(
get_rect
(
temp
))
==
false
)
{
if
(
res
.
size
()
==
0
)
{
res
.
set_size
(
temp
.
nr
(),
temp
.
nc
());
res
=
0
;
}
// Make res bigger by padding with zeros on the bottom or right if necessary.
if
(
res
.
nr
()
<
temp
.
nr
())
res
=
join_cols
(
res
,
zeros_matrix
<
double
>
(
temp
.
nr
()
-
res
.
nc
(),
res
.
nc
()));
if
(
res
.
nc
()
<
temp
.
nc
())
res
=
join_rows
(
res
,
zeros_matrix
<
double
>
(
res
.
nr
(),
temp
.
nc
()
-
res
.
nc
()));
}
// add temp to res
for
(
long
r
=
0
;
r
<
temp
.
nr
();
++
r
)
{
for
(
long
c
=
0
;
c
<
temp
.
nc
();
++
c
)
{
res
(
r
,
c
)
+=
temp
(
r
,
c
);
}
}
res
+=
test_sequence_labeler
(
trainer
.
train
(
x_train
,
y_train
),
x_test
,
y_test
);
}
// for (long i = 0; i < folds; ++i)
...
...
dlib/svm/cross_validate_sequence_labeler_abstract.h
View file @
fdc3af3a
...
...
@@ -25,14 +25,18 @@ namespace dlib
/*!
requires
- is_sequence_labeling_problem(samples, labels)
- sequence_labeler_type == dlib::sequence_labeler or an object with a
compatible interface.
ensures
- Tests labeler against the given samples and labels and returns a confusion
matrix summarizing the results.
- The confusion matrix C returned by this function has the following properties.
- C.nc() == labeler.num_labels()
- C.nr() ==
max(labeler.num_labels(), max value in labels + 1)
- C.nr() ==
labeler.num_labels()
- C(T,P) == the number of times a sample with label T was predicted
to have a label of P.
- Any samples with a label value >= labeler.num_labels() are ignored. That
is, samples with labels the labeler hasn't ever seen before are ignored.
!*/
// ----------------------------------------------------------------------------------------
...
...
@@ -51,6 +55,22 @@ namespace dlib
requires
- is_sequence_labeling_problem(samples, labels)
- 1 < folds <= samples.size()
- for all valid i and j: labels[i][j] < trainer.num_labels()
- trainer_type == dlib::structural_sequence_labeling_trainer or an object
with a compatible interface.
ensures
- performs k-fold cross validation by using the given trainer to solve the
given sequence labeling problem for the given number of folds. Each fold
is tested using the output of the trainer and the confusion matrix from all
folds is summed and returned.
- The total confusion matrix is computed by running test_sequence_labeler()
on each fold and summing its output.
- The number of folds used is given by the folds argument.
- The confusion matrix C returned by this function has the following properties.
- C.nc() == trainer.num_labels()
- C.nr() == trainer.num_labels()
- C(T,P) == the number of times a sample with label T was predicted
to have a label of P.
!*/
// ----------------------------------------------------------------------------------------
...
...
dlib/svm/sequence_labeler.h
View file @
fdc3af3a
...
...
@@ -112,7 +112,10 @@ namespace dlib
public
:
sequence_labeler
()
{}
{
weights
.
set_size
(
fe
.
num_features
());
weights
=
0
;
}
sequence_labeler
(
const
feature_extractor
&
fe_
,
...
...
@@ -120,7 +123,16 @@ namespace dlib
)
:
fe
(
fe_
),
weights
(
weights_
)
{}
{
// make sure requires clause is not broken
DLIB_ASSERT
(
fe_
.
num_features
()
==
weights_
.
size
(),
"
\t
sequence_labeler::sequence_labeler()"
<<
"
\n\t
These sizes should match"
<<
"
\n\t
fe_.num_features(): "
<<
fe_
.
num_features
()
<<
"
\n\t
weights_.size(): "
<<
weights_
.
size
()
<<
"
\n\t
this: "
<<
this
);
}
const
feature_extractor
&
get_feature_extractor
(
)
const
{
return
fe
;
}
...
...
@@ -135,6 +147,13 @@ namespace dlib
const
sample_sequence_type
&
x
)
const
{
// make sure requires clause is not broken
DLIB_ASSERT
(
num_labels
()
>
0
,
"
\t
labeled_sequence_type sequence_labeler::operator()(x)"
<<
"
\n\t
You can't have no labels."
<<
"
\n\t
this: "
<<
this
);
labeled_sequence_type
y
;
find_max_factor_graph_viterbi
(
map_prob
(
x
,
fe
,
weights
),
y
);
return
y
;
...
...
@@ -145,6 +164,13 @@ namespace dlib
labeled_sequence_type
&
y
)
const
{
// make sure requires clause is not broken
DLIB_ASSERT
(
num_labels
()
>
0
,
"
\t
void sequence_labeler::label_sequence(x,y)"
<<
"
\n\t
You can't have no labels."
<<
"
\n\t
this: "
<<
this
);
find_max_factor_graph_viterbi
(
map_prob
(
x
,
fe
,
weights
),
y
);
}
...
...
dlib/svm/structural_sequence_labeling_trainer.h
View file @
fdc3af3a
...
...
@@ -34,12 +34,49 @@ namespace dlib
structural_sequence_labeling_trainer
(
)
{}
const
feature_extractor
&
get_feature_extractor
(
)
const
{
return
fe
;
}
unsigned
long
num_labels
(
)
const
{
return
fe
.
num_labels
();
}
const
sequence_labeler
<
feature_extractor
>
train
(
const
std
::
vector
<
sample_sequence_type
>&
x
,
const
std
::
vector
<
labeled_sequence_type
>&
y
)
const
{
// make sure requires clause is not broken
DLIB_ASSERT
(
is_sequence_labeling_problem
(
x
,
y
)
==
true
,
"
\t
sequence_labeler structural_sequence_labeling_trainer::train(x,y)"
<<
"
\n\t
invalid inputs were given to this function"
<<
"
\n\t
x.size(): "
<<
x
.
size
()
<<
"
\n\t
is_sequence_labeling_problem(x,y): "
<<
is_sequence_labeling_problem
(
x
,
y
)
);
#ifdef ENABLE_ASSERTS
for
(
unsigned
long
i
=
0
;
i
<
y
.
size
();
++
i
)
{
for
(
unsigned
long
j
=
0
;
j
<
y
[
i
].
size
();
++
j
)
{
// make sure requires clause is not broken
DLIB_ASSERT
(
y
[
i
][
j
]
<
num_labels
(),
"
\t
sequence_labeler structural_sequence_labeling_trainer::train(x,y)"
<<
"
\n\t
The given labels in y are invalid."
<<
"
\n\t
y[i][j]: "
<<
y
[
i
][
j
]
<<
"
\n\t
num_labels(): "
<<
num_labels
()
<<
"
\n\t
i: "
<<
i
<<
"
\n\t
j: "
<<
j
<<
"
\n\t
this: "
<<
this
);
}
}
#endif
structural_svm_sequence_labeling_problem
<
feature_extractor
>
prob
(
x
,
y
,
fe
);
oca
solver
;
matrix
<
double
,
0
,
1
>
weights
;
...
...
dlib/svm/structural_sequence_labeling_trainer_abstract.h
View file @
fdc3af3a
...
...
@@ -37,10 +37,40 @@ namespace dlib
structural_sequence_labeling_trainer
(
)
{}
const
feature_extractor
&
get_feature_extractor
(
)
const
{
return
fe
;
}
/*!
ensures
- returns the feature extractor used by this object
!*/
unsigned
long
num_labels
(
)
const
{
return
fe
.
num_labels
();
}
/*!
ensures
- returns get_feature_extractor().num_labels()
(i.e. returns the number of possible output labels for each
element of a sequence)
!*/
const
sequence_labeler
<
feature_extractor
>
train
(
const
std
::
vector
<
sample_sequence_type
>&
x
,
const
std
::
vector
<
labeled_sequence_type
>&
y
)
const
;
/*!
requires
- is_sequence_labeling_problem(x, y)
- for all valid i and j: y[i][j] < num_labels()
ensures
- Uses the structural_svm_sequence_labeling_problem to train a
sequence_labeler on the given x/y training pairs. The idea is
to learn to predict a y given an input x.
- returns a function F with the following properties:
- F(new_x) == A sequence of predicted labels for the elements of new_x.
- F(new_x).size() == new_x.size()
- for all valid i:
- F(new_x)[i] == the predicted label of new_x[i]
!*/
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment