Commit 0246088a authored by Davis King's avatar Davis King

Added a per node loss interface for the structural_graph_labeling_trainer.

parent 7aab9f71
...@@ -167,20 +167,23 @@ namespace dlib ...@@ -167,20 +167,23 @@ namespace dlib
> >
const graph_labeler<vector_type> train ( const graph_labeler<vector_type> train (
const dlib::array<graph_type>& samples, const dlib::array<graph_type>& samples,
const std::vector<label_type>& labels const std::vector<label_type>& labels,
const std::vector<std::vector<double> >& losses
) const ) const
{ {
// make sure requires clause is not broken DLIB_ASSERT(is_graph_labeling_problem(samples, labels) == true &&
DLIB_ASSERT(is_graph_labeling_problem(samples, labels), (losses.size() == 0 || sizes_match(labels, losses) == true) &&
"\t void structural_graph_labeling_trainer::train()" all_values_are_nonnegative(losses) == true,
<< "\n\t Invalid inputs were given to this function." "\t void structural_graph_labeling_trainer::train()"
<< "\n\t samples.size(): " << samples.size() << "\n\t Invalid inputs were given to this function."
<< "\n\t labels.size(): " << labels.size() << "\n\t samples.size(): " << samples.size()
<< "\n\t this: " << this << "\n\t labels.size(): " << labels.size()
); << "\n\t losses.size(): " << losses.size()
<< "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses)
<< "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses)
<< "\n\t this: " << this );
std::vector<std::vector<double> > losses;
structural_svm_graph_labeling_problem<graph_type> prob(samples, labels, losses, num_threads); structural_svm_graph_labeling_problem<graph_type> prob(samples, labels, losses, num_threads);
if (verbose) if (verbose)
...@@ -189,8 +192,11 @@ namespace dlib ...@@ -189,8 +192,11 @@ namespace dlib
prob.set_c(C); prob.set_c(C);
prob.set_epsilon(eps); prob.set_epsilon(eps);
prob.set_max_cache_size(max_cache_size); prob.set_max_cache_size(max_cache_size);
prob.set_loss_on_positive_class(loss_pos); if (prob.get_losses().size() == 0)
prob.set_loss_on_negative_class(loss_neg); {
prob.set_loss_on_positive_class(loss_pos);
prob.set_loss_on_negative_class(loss_neg);
}
matrix<double,0,1> w; matrix<double,0,1> w;
solver(prob, w, prob.get_num_edge_weights()); solver(prob, w, prob.get_num_edge_weights());
...@@ -201,6 +207,18 @@ namespace dlib ...@@ -201,6 +207,18 @@ namespace dlib
return graph_labeler<vector_type>(edge_weights, node_weights); return graph_labeler<vector_type>(edge_weights, node_weights);
} }
template <
    typename graph_type
>
const graph_labeler<vector_type> train (
    const dlib::array<graph_type>& samples,
    const std::vector<label_type>& labels
) const
{
    // Convenience overload without per-node losses.  Delegating with an
    // empty loss set makes the loss-aware overload fall back to the
    // get_loss_on_positive_class()/get_loss_on_negative_class() values.
    const std::vector<std::vector<double> > no_losses;
    return train(samples, labels, no_losses);
}
private: private:
template <typename T> template <typename T>
......
...@@ -212,14 +212,49 @@ namespace dlib ...@@ -212,14 +212,49 @@ namespace dlib
requires requires
- is_graph_labeling_problem(samples,labels) == true - is_graph_labeling_problem(samples,labels) == true
ensures ensures
- Uses the structural_svm_graph_labeling_problem to train a - Uses the structural_svm_graph_labeling_problem to train a graph_labeler
graph_labeler on the given samples/labels training pairs. on the given samples/labels training pairs. The idea is to learn to
The idea is to learn to predict a label given an input sample. predict a label given an input sample.
- The values of get_loss_on_positive_class() and get_loss_on_negative_class()
are used to determine how to value mistakes on each node during training.
- returns a function F with the following properties: - returns a function F with the following properties:
- F(new_sample) == The predicted labels for the nodes in the - F(new_sample) == The predicted labels for the nodes in the graph
graph new_sample. new_sample.
!*/ !*/
template <
typename graph_type
>
const graph_labeler<vector_type> train (
const dlib::array<graph_type>& samples,
const std::vector<label_type>& labels,
const std::vector<std::vector<double> >& losses
) const;
/*!
requires
- is_graph_labeling_problem(samples,labels) == true
- if (losses.size() != 0) then
- sizes_match(labels, losses) == true
- all_values_are_nonnegative(losses) == true
ensures
- Uses the structural_svm_graph_labeling_problem to train a graph_labeler
on the given samples/labels training pairs. The idea is to learn to
predict a label given an input sample.
- returns a function F with the following properties:
- F(new_sample) == The predicted labels for the nodes in the graph
new_sample.
- if (losses.size() == 0) then
- The values of get_loss_on_positive_class() and get_loss_on_negative_class()
are used to determine how to value mistakes on each node during training.
- The losses argument is effectively ignored if its size is zero.
- else
- Each node in the training data has its own loss value defined by the
corresponding entry of losses. In particular, this means that the
node with label labels[i][j] incurs a loss of losses[i][j] if it is
incorrectly labeled.
- The get_loss_on_positive_class() and get_loss_on_negative_class()
parameters are ignored. Only losses is used in this case.
!*/
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment