Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
D
dlib
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
钟尚武
dlib
Commits
0246088a
Commit
0246088a
authored
Jul 29, 2012
by
Davis King
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Added a per node loss interface for the structural_graph_labeling_trainer.
parent
7aab9f71
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
70 additions
and
17 deletions
+70
-17
structural_graph_labeling_trainer.h
dlib/svm/structural_graph_labeling_trainer.h
+30
-12
structural_graph_labeling_trainer_abstract.h
dlib/svm/structural_graph_labeling_trainer_abstract.h
+40
-5
No files found.
dlib/svm/structural_graph_labeling_trainer.h
View file @
0246088a
...
...
@@ -167,20 +167,23 @@ namespace dlib
>
const
graph_labeler
<
vector_type
>
train
(
const
dlib
::
array
<
graph_type
>&
samples
,
const
std
::
vector
<
label_type
>&
labels
const
std
::
vector
<
label_type
>&
labels
,
const
std
::
vector
<
std
::
vector
<
double
>
>&
losses
)
const
{
// make sure requires clause is not broken
DLIB_ASSERT
(
is_graph_labeling_problem
(
samples
,
labels
),
"
\t
void structural_graph_labeling_trainer::train()"
<<
"
\n\t
Invalid inputs were given to this function."
<<
"
\n\t
samples.size(): "
<<
samples
.
size
()
<<
"
\n\t
labels.size(): "
<<
labels
.
size
()
<<
"
\n\t
this: "
<<
this
);
DLIB_ASSERT
(
is_graph_labeling_problem
(
samples
,
labels
)
==
true
&&
(
losses
.
size
()
==
0
||
sizes_match
(
labels
,
losses
)
==
true
)
&&
all_values_are_nonnegative
(
losses
)
==
true
,
"
\t
void structural_graph_labeling_trainer::train()"
<<
"
\n\t
Invalid inputs were given to this function."
<<
"
\n\t
samples.size(): "
<<
samples
.
size
()
<<
"
\n\t
labels.size(): "
<<
labels
.
size
()
<<
"
\n\t
losses.size(): "
<<
losses
.
size
()
<<
"
\n\t
sizes_match(labels,losses): "
<<
sizes_match
(
labels
,
losses
)
<<
"
\n\t
all_values_are_nonnegative(losses): "
<<
all_values_are_nonnegative
(
losses
)
<<
"
\n\t
this: "
<<
this
);
std
::
vector
<
std
::
vector
<
double
>
>
losses
;
structural_svm_graph_labeling_problem
<
graph_type
>
prob
(
samples
,
labels
,
losses
,
num_threads
);
if
(
verbose
)
...
...
@@ -189,8 +192,11 @@ namespace dlib
prob
.
set_c
(
C
);
prob
.
set_epsilon
(
eps
);
prob
.
set_max_cache_size
(
max_cache_size
);
prob
.
set_loss_on_positive_class
(
loss_pos
);
prob
.
set_loss_on_negative_class
(
loss_neg
);
if
(
prob
.
get_losses
().
size
()
==
0
)
{
prob
.
set_loss_on_positive_class
(
loss_pos
);
prob
.
set_loss_on_negative_class
(
loss_neg
);
}
matrix
<
double
,
0
,
1
>
w
;
solver
(
prob
,
w
,
prob
.
get_num_edge_weights
());
...
...
@@ -201,6 +207,18 @@ namespace dlib
return
graph_labeler
<
vector_type
>
(
edge_weights
,
node_weights
);
}
template
<
typename
graph_type
>
const
graph_labeler
<
vector_type
>
train
(
const
dlib
::
array
<
graph_type
>&
samples
,
const
std
::
vector
<
label_type
>&
labels
)
const
{
std
::
vector
<
std
::
vector
<
double
>
>
losses
;
return
train
(
samples
,
labels
,
losses
);
}
private
:
template
<
typename
T
>
...
...
dlib/svm/structural_graph_labeling_trainer_abstract.h
View file @
0246088a
...
...
@@ -212,14 +212,49 @@ namespace dlib
requires
- is_graph_labeling_problem(samples,labels) == true
ensures
- Uses the structural_svm_graph_labeling_problem to train a
graph_labeler on the given samples/labels training pairs.
The idea is to learn to predict a label given an input sample.
- Uses the structural_svm_graph_labeling_problem to train a graph_labeler
on the given samples/labels training pairs. The idea is to learn to
predict a label given an input sample.
- The values of get_loss_on_positive_class() and get_loss_on_negative_class()
are used to determine how to value mistakes on each node during training.
- returns a function F with the following properties:
- F(new_sample) == The predicted labels for the nodes in the
graph
new_sample.
- F(new_sample) == The predicted labels for the nodes in the
graph
new_sample.
!*/
template
<
typename
graph_type
>
const
graph_labeler
<
vector_type
>
train
(
const
dlib
::
array
<
graph_type
>&
samples
,
const
std
::
vector
<
label_type
>&
labels
,
const
std
::
vector
<
std
::
vector
<
double
>
>&
losses
)
const
;
/*!
requires
- is_graph_labeling_problem(samples,labels) == true
- if (losses.size() != 0) then
- sizes_match(labels, losses) == true
- all_values_are_nonnegative(losses) == true
ensures
- Uses the structural_svm_graph_labeling_problem to train a graph_labeler
on the given samples/labels training pairs. The idea is to learn to
predict a label given an input sample.
- returns a function F with the following properties:
- F(new_sample) == The predicted labels for the nodes in the graph
new_sample.
- if (losses.size() == 0) then
- The values of get_loss_on_positive_class() and get_loss_on_negative_class()
are used to determine how to value mistakes on each node during training.
- The losses argument is effectively ignored if its size is zero.
- else
- Each node in the training data has its own loss value defined by the
corresponding entry of losses. In particular, this means that the
node with label labels[i][j] incurs a loss of losses[i][j] if it is
incorrectly labeled.
- The get_loss_on_positive_class() and get_loss_on_negative_class()
parameters are ignored. Only losses is used in this case.
!*/
};
// ----------------------------------------------------------------------------------------
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment