Commit 0246088a authored by Davis King's avatar Davis King

Added a per node loss interface for the structural_graph_labeling_trainer.

parent 7aab9f71
......@@ -167,20 +167,23 @@ namespace dlib
>
const graph_labeler<vector_type> train (
const dlib::array<graph_type>& samples,
const std::vector<label_type>& labels
const std::vector<label_type>& labels,
const std::vector<std::vector<double> >& losses
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(is_graph_labeling_problem(samples, labels),
"\t void structural_graph_labeling_trainer::train()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t samples.size(): " << samples.size()
<< "\n\t labels.size(): " << labels.size()
<< "\n\t this: " << this
);
DLIB_ASSERT(is_graph_labeling_problem(samples, labels) == true &&
(losses.size() == 0 || sizes_match(labels, losses) == true) &&
all_values_are_nonnegative(losses) == true,
"\t void structural_graph_labeling_trainer::train()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t samples.size(): " << samples.size()
<< "\n\t labels.size(): " << labels.size()
<< "\n\t losses.size(): " << losses.size()
<< "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses)
<< "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses)
<< "\n\t this: " << this );
std::vector<std::vector<double> > losses;
structural_svm_graph_labeling_problem<graph_type> prob(samples, labels, losses, num_threads);
if (verbose)
......@@ -189,8 +192,11 @@ namespace dlib
prob.set_c(C);
prob.set_epsilon(eps);
prob.set_max_cache_size(max_cache_size);
prob.set_loss_on_positive_class(loss_pos);
prob.set_loss_on_negative_class(loss_neg);
if (prob.get_losses().size() == 0)
{
prob.set_loss_on_positive_class(loss_pos);
prob.set_loss_on_negative_class(loss_neg);
}
matrix<double,0,1> w;
solver(prob, w, prob.get_num_edge_weights());
......@@ -201,6 +207,18 @@ namespace dlib
return graph_labeler<vector_type>(edge_weights, node_weights);
}
template <
typename graph_type
>
const graph_labeler<vector_type> train (
const dlib::array<graph_type>& samples,
const std::vector<label_type>& labels
) const
{
std::vector<std::vector<double> > losses;
return train(samples, labels, losses);
}
private:
template <typename T>
......
......@@ -212,14 +212,49 @@ namespace dlib
requires
- is_graph_labeling_problem(samples,labels) == true
ensures
- Uses the structural_svm_graph_labeling_problem to train a
graph_labeler on the given samples/labels training pairs.
The idea is to learn to predict a label given an input sample.
- Uses the structural_svm_graph_labeling_problem to train a graph_labeler
on the given samples/labels training pairs. The idea is to learn to
predict a label given an input sample.
- The values of get_loss_on_positive_class() and get_loss_on_negative_class()
are used to determine how to value mistakes on each node during training.
- returns a function F with the following properties:
- F(new_sample) == The predicted labels for the nodes in the
graph new_sample.
- F(new_sample) == The predicted labels for the nodes in the graph
new_sample.
!*/
template <
typename graph_type
>
const graph_labeler<vector_type> train (
const dlib::array<graph_type>& samples,
const std::vector<label_type>& labels,
const std::vector<std::vector<double> >& losses
) const;
/*!
requires
- is_graph_labeling_problem(samples,labels) == true
- if (losses.size() != 0) then
- sizes_match(labels, losses) == true
- all_values_are_nonnegative(losses) == true
ensures
- Uses the structural_svm_graph_labeling_problem to train a graph_labeler
on the given samples/labels training pairs. The idea is to learn to
predict a label given an input sample.
- returns a function F with the following properties:
- F(new_sample) == The predicted labels for the nodes in the graph
new_sample.
- if (losses.size() == 0) then
- The values of get_loss_on_positive_class() and get_loss_on_negative_class()
are used to determine how to value mistakes on each node during training.
- The losses argument is effectively ignored if its size is zero.
- else
- Each node in the training data has its own loss value defined by the
corresponding entry of losses. In particular, this means that the
node with label labels[i][j] incurs a loss of losses[i][j] if it is
incorrectly labeled.
- The get_loss_on_positive_class() and get_loss_on_negative_class()
parameters are ignored. Only losses is used in this case.
!*/
};
// ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment