Commit efb1a12d authored by Davis King's avatar Davis King

Added the ability for the user to set the per class loss.

parent 2c45ab5e
......@@ -149,6 +149,8 @@ namespace dlib
<< "\n\t labels.size(): " << labels.size()
<< "\n\t this: " << this );
loss_pos = 1.0;
loss_neg = 1.0;
// figure out how many dimensions are in the node and edge vectors.
node_dims = 0;
......@@ -172,6 +174,41 @@ namespace dlib
return edge_dims;
}
void set_loss_on_positive_class (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss >= 0,
"\t structural_svm_graph_labeling_problem::set_loss_on_positive_class()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t loss: " << loss
<< "\n\t this: " << this );
loss_pos = loss;
}
void set_loss_on_negative_class (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss >= 0,
"\t structural_svm_graph_labeling_problem::set_loss_on_negative_class()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t loss: " << loss
<< "\n\t this: " << this );
loss_pos = loss;
}
double get_loss_on_negative_class (
) const { return loss_neg; }
double get_loss_on_positive_class (
) const { return loss_pos; }
private:
virtual long get_num_dimensions (
) const
......@@ -303,9 +340,9 @@ namespace dlib
// max when we use find_max_factor_graph_potts() below.
const bool label_i = (labels[idx][i]!=0);
if (label_i)
g.node(i).data -= 1.0;
g.node(i).data -= loss_pos;
else
g.node(i).data += 1.0;
g.node(i).data += loss_neg;
for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n)
{
......@@ -331,7 +368,12 @@ namespace dlib
const bool true_label = (labels[idx][i]!= 0);
const bool pred_label = (labeling[i]!= 0);
if (true_label != pred_label)
++loss;
{
if (true_label == true)
loss += loss_pos;
else
loss += loss_neg;
}
}
// compute psi
......@@ -343,6 +385,8 @@ namespace dlib
long node_dims;
long edge_dims;
double loss_pos;
double loss_neg;
};
// ----------------------------------------------------------------------------------------
......
......@@ -120,6 +120,46 @@ namespace dlib
part of the total weight vector. You can do this by passing get_num_edge_weights()
to the third argument to oca::operator().
!*/
void set_loss_on_positive_class (
double loss
);
/*!
requires
- loss >= 0
ensures
- #get_loss_on_positive_class() == loss
!*/
void set_loss_on_negative_class (
double loss
);
/*!
requires
- loss >= 0
ensures
- #get_loss_on_negative_class() == loss
!*/
double get_loss_on_positive_class (
) const;
/*!
ensures
- returns the loss incurred when a graph node which is supposed to have
a label of true gets misclassified. This value controls how much we care
about correctly classifying nodes which should be labeled as true. Larger
loss values indicate that we care more strongly than smaller values.
!*/
double get_loss_on_negative_class (
) const;
/*!
ensures
- returns the loss incurred when a graph node which is supposed to have
a label of false gets misclassified. This value controls how much we care
about correctly classifying nodes which should be labeled as false. Larger
loss values indicate that we care more strongly than smaller values.
!*/
};
// ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment