Commit 5edb76a6 authored by Davis King

Updated the sequence labeling trainer to allow the user to set different loss values for different labels.
parent e7770786
@@ -134,6 +134,40 @@ namespace dlib
return C;
}
double get_loss (
unsigned long label
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(label < num_labels(),
"\t void structural_sequence_labeling_trainer::get_loss()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t label: " << label
<< "\n\t num_labels(): " << num_labels()
<< "\n\t this: " << this
);
return loss_values[label];
}
void set_loss (
unsigned long label,
double value
)
{
// make sure requires clause is not broken
DLIB_ASSERT(label < num_labels() && value >= 0,
"\t void structural_sequence_labeling_trainer::set_loss()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t label: " << label
<< "\n\t num_labels(): " << num_labels()
<< "\n\t value: " << value
<< "\n\t this: " << this
);
loss_values[label] = value;
}
const sequence_labeler<feature_extractor> train(
const std::vector<sample_sequence_type>& x,
@@ -182,6 +216,9 @@ namespace dlib
prob.set_epsilon(eps);
prob.set_c(C);
prob.set_max_cache_size(max_cache_size);
for (unsigned long i = 0; i < loss_values.size(); ++i)
prob.set_loss(i,loss_values[i]);
solver(prob, weights);
return sequence_labeler<feature_extractor>(weights,fe);
@@ -195,6 +232,7 @@ namespace dlib
bool verbose;
unsigned long num_threads;
unsigned long max_cache_size;
std::vector<double> loss_values;
void set_defaults ()
{
@@ -203,6 +241,7 @@ namespace dlib
eps = 0.1;
num_threads = 2;
max_cache_size = 40;
loss_values.assign(num_labels(), 1);
}
feature_extractor fe;
@@ -190,6 +190,31 @@ namespace dlib
better generalization.
!*/
double get_loss (
unsigned long label
) const;
/*!
requires
- label < num_labels()
ensures
- returns the loss incurred when a sequence element with the given
label is misclassified. This value controls how much we care about
correctly classifying this type of label. Larger loss values indicate
that we care more strongly than smaller values.
!*/
void set_loss (
unsigned long label,
double value
);
/*!
requires
- label < num_labels()
- value >= 0
ensures
- #get_loss(label) == value
!*/
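For illustration, a minimal usage sketch of the per-label loss API added by this commit. It assumes a hypothetical user-defined feature extractor type, my_feature_extractor, implementing dlib's sequence labeling feature extractor interface, and training data x/y as passed to train() below:
#include <dlib/svm_threaded.h>
// Sketch only: my_feature_extractor, x, and y are assumed to be
// defined by the user as in any dlib sequence labeling setup.
my_feature_extractor fe;
structural_sequence_labeling_trainer<my_feature_extractor> trainer(fe);
// Every label's loss defaults to 1.  Suppose label 2 is rare but
// important: make misclassifying it three times as costly.
trainer.set_loss(2, 3.0);
const sequence_labeler<my_feature_extractor> labeler = trainer.train(x, y);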
const sequence_labeler<feature_extractor> train(
const std::vector<sample_sequence_type>& x,
const std::vector<labeled_sequence_type>& y
@@ -114,6 +114,45 @@ namespace dlib
}
#endif
loss_values.assign(num_labels(), 1);
}
unsigned long num_labels (
) const { return fe.num_labels(); }
double get_loss (
unsigned long label
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(label < num_labels(),
"\t void structural_svm_sequence_labeling_problem::get_loss()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t label: " << label
<< "\n\t num_labels(): " << num_labels()
<< "\n\t this: " << this
);
return loss_values[label];
}
void set_loss (
unsigned long label,
double value
)
{
// make sure requires clause is not broken
DLIB_ASSERT(label < num_labels() && value >= 0,
"\t void structural_svm_sequence_labeling_problem::set_loss()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t label: " << label
<< "\n\t num_labels(): " << num_labels()
<< "\n\t value: " << value
<< "\n\t this: " << this
);
loss_values[label] = value;
}
private:
@@ -166,12 +205,14 @@ namespace dlib
const sequence_type& sequence_,
const std::vector<unsigned long>& label_,
const feature_extractor& fe_,
-const matrix<double,0,1>& weights_
+const matrix<double,0,1>& weights_,
+const std::vector<double>& loss_values_
) :
sequence(sequence_),
label(label_),
fe(fe_),
-weights(weights_)
+weights(weights_),
+loss_values(loss_values_)
{
}
@@ -194,7 +235,7 @@ namespace dlib
double loss = 0;
if (node_states(0) != label[node_id])
-loss = 1;
+loss = loss_values[label[node_id]];
return fe_helpers::dot(weights, fe, sequence, node_states, node_id) + loss;
}
@@ -203,6 +244,7 @@ namespace dlib
const std::vector<unsigned long>& label;
const feature_extractor& fe;
const matrix<double,0,1>& weights;
const std::vector<double>& loss_values;
};
virtual void separation_oracle (
@@ -213,13 +255,13 @@ namespace dlib
) const
{
std::vector<unsigned long> y;
-find_max_factor_graph_viterbi(map_prob(samples[idx],labels[idx],fe,current_solution), y);
+find_max_factor_graph_viterbi(map_prob(samples[idx],labels[idx],fe,current_solution,loss_values), y);
loss = 0;
for (unsigned long i = 0; i < y.size(); ++i)
{
if (y[i] != labels[idx][i])
-loss += 1;
+loss += loss_values[labels[idx][i]];
}
get_joint_feature_vector(samples[idx], y, psi);
@@ -228,6 +270,7 @@ namespace dlib
const std::vector<sequence_type>& samples;
const std::vector<std::vector<unsigned long> >& labels;
const feature_extractor& fe;
std::vector<double> loss_values;
};
// ----------------------------------------------------------------------------------------
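Taken together, the map_prob and separation_oracle changes above replace the flat cost of 1 per mislabeled element with a weighted Hamming loss. A self-contained sketch of that accumulation, with truth, predicted, and loss_values standing in for labels[idx], y, and the new per-label loss table:
#include <vector>
// Weighted Hamming loss between a true and a predicted label sequence,
// mirroring what the separation oracle now computes.
double weighted_hamming_loss (
    const std::vector<unsigned long>& truth,
    const std::vector<unsigned long>& predicted,
    const std::vector<double>& loss_values
)
{
    double loss = 0;
    for (unsigned long i = 0; i < truth.size(); ++i)
    {
        // each mistake costs the loss assigned to the true label,
        // not a flat cost of 1
        if (predicted[i] != truth[i])
            loss += loss_values[truth[i]];
    }
    return loss;
}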
@@ -64,6 +64,40 @@ namespace dlib
- This object will use num_threads threads during the optimization
procedure. You should set this parameter equal to the number of
available processing cores on your machine.
- #num_labels() == fe.num_labels()
- for all valid i: #get_loss(i) == 1
!*/
unsigned long num_labels (
) const;
/*!
ensures
- returns the number of possible labels in this learning problem
!*/
double get_loss (
unsigned long label
) const;
/*!
requires
- label < num_labels()
ensures
- returns the loss incurred when a sequence element with the given
label is misclassified. This value controls how much we care about
correctly classifying this type of label. Larger loss values indicate
that we care more strongly than smaller values.
!*/
void set_loss (
unsigned long label,
double value
);
/*!
requires
- label < num_labels()
- value >= 0
ensures
- #get_loss(label) == value
!*/
};
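To make the weighting concrete: if get_loss(0) == 3 while every other label keeps the default loss of 1, then a candidate labeling that misclassifies two elements whose true label is 0 and one element whose true label is 1 incurs a total loss of 3 + 3 + 1 = 7. Mistakes on label 0 are therefore three times as expensive for the solver to make.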