Commit 3a5f99f4 authored by Davis King's avatar Davis King

Updated the interface to the structural_graph_labeling_trainer so the user

can set the per class loss to whatever they want.
parent 42c123f2
......@@ -32,6 +32,8 @@ namespace dlib
eps = 0.1;
num_threads = 2;
max_cache_size = 40;
loss_pos = 1.0;
loss_neg = 1.0;
}
void set_num_threads (
......@@ -124,6 +126,42 @@ namespace dlib
return C;
}
void set_loss_on_positive_class (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss >= 0,
"\t structural_graph_labeling_trainer::set_loss_on_positive_class()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t loss: " << loss
<< "\n\t this: " << this );
loss_pos = loss;
}
void set_loss_on_negative_class (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss >= 0,
"\t structural_graph_labeling_trainer::set_loss_on_negative_class()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t loss: " << loss
<< "\n\t this: " << this );
loss_neg = loss;
}
double get_loss_on_negative_class (
) const { return loss_neg; }
double get_loss_on_positive_class (
) const { return loss_pos; }
template <
typename graph_type
>
......@@ -150,6 +188,8 @@ namespace dlib
prob.set_c(C);
prob.set_epsilon(eps);
prob.set_max_cache_size(max_cache_size);
prob.set_loss_on_positive_class(loss_pos);
prob.set_loss_on_negative_class(loss_neg);
matrix<double,0,1> w;
solver(prob, w, prob.get_num_edge_weights());
......@@ -203,6 +243,8 @@ namespace dlib
bool verbose;
unsigned long num_threads;
unsigned long max_cache_size;
double loss_pos;
double loss_neg;
};
// ----------------------------------------------------------------------------------------
......
......@@ -48,6 +48,8 @@ namespace dlib
- #get_epsilon() == 0.1
- #get_num_threads() == 2
- #get_max_cache_size() == 40
- #get_loss_on_positive_class() == 1.0
- #get_loss_on_negative_class() == 1.0
!*/
void set_num_threads (
......@@ -159,6 +161,46 @@ namespace dlib
better generalization.
!*/
void set_loss_on_positive_class (
double loss
);
/*!
requires
- loss >= 0
ensures
- #get_loss_on_positive_class() == loss
!*/
void set_loss_on_negative_class (
double loss
);
/*!
requires
- loss >= 0
ensures
- #get_loss_on_negative_class() == loss
!*/
double get_loss_on_positive_class (
) const;
/*!
ensures
- returns the loss incurred when a graph node which is supposed to have
a label of true gets misclassified. This value controls how much we care
about correctly classifying nodes which should be labeled as true. Larger
loss values indicate that we care more strongly than smaller values.
!*/
double get_loss_on_negative_class (
) const;
/*!
ensures
- returns the loss incurred when a graph node which is supposed to have
a label of false gets misclassified. This value controls how much we care
about correctly classifying nodes which should be labeled as false. Larger
loss values indicate that we care more strongly than smaller values.
!*/
template <
typename graph_type
>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment