Commit 374459b4 authored by Davis King's avatar Davis King

Refactored this code a little

parent 7e28ce5e
...@@ -333,9 +333,9 @@ namespace dlib ...@@ -333,9 +333,9 @@ namespace dlib
// Include a loss augmentation so that we will get the proper loss augmented // Include a loss augmentation so that we will get the proper loss augmented
// max when we use find_max_factor_graph_potts() below. // max when we use find_max_factor_graph_potts() below.
if (labels[idx][i]) if (labels[idx][i])
g.node(i).data -= loss_pos; g.node(i).data -= get_loss_for_sample(idx,i,!labels[idx][i]);
else else
g.node(i).data += loss_neg; g.node(i).data += get_loss_for_sample(idx,i,!labels[idx][i]);
for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n) for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n)
{ {
...@@ -360,20 +360,42 @@ namespace dlib ...@@ -360,20 +360,42 @@ namespace dlib
loss = 0; loss = 0;
for (unsigned long i = 0; i < labeling.size(); ++i) for (unsigned long i = 0; i < labeling.size(); ++i)
{ {
const bool true_label = labels[idx][i]; const bool predicted_label = (labeling[i]!= 0);
const bool pred_label = (labeling[i]!= 0); bool_labeling.push_back(predicted_label);
bool_labeling.push_back(pred_label); loss += get_loss_for_sample(idx, i, predicted_label);
if (true_label != pred_label) }
// compute psi
get_joint_feature_vector(samp, bool_labeling, psi);
}
double get_loss_for_sample (
long sample_idx,
long node_idx,
bool predicted_label
) const
/*!
requires
- 0 <= sample_idx < labels.size()
- 0 <= node_idx < labels[sample_idx].size()
ensures
- returns the loss incurred for predicting that the node
samples[sample_idx].node(node_idx) has a label of predicted_label.
!*/
{
const bool true_label = labels[sample_idx][node_idx];
if (true_label != predicted_label)
{ {
if (true_label == true) if (true_label == true)
loss += loss_pos; return loss_pos;
else else
loss += loss_neg; return loss_neg;
} }
else
{
// no loss for making the correct prediction.
return 0;
} }
// compute psi
get_joint_feature_vector(samp, bool_labeling, psi);
} }
const dlib::array<sample_type>& samples; const dlib::array<sample_type>& samples;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment