Commit 5edb76a6 authored by Davis King

Updated the sequence labeling trainer to allow the user to set different loss
values for different labels.

parent e7770786
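For context, here is a minimal usage sketch of the new per-label loss API (editorial, not part of the commit). my_feature_extractor is a hypothetical user-defined type implementing dlib's sequence feature extractor interface, assumed to have num_labels() == 2; everything else uses only functions confirmed by the diff below (set_loss(), get_loss(), train()).

    #include <dlib/svm_threaded.h>
    #include <vector>

    // my_feature_extractor is hypothetical: any user-defined type implementing
    // dlib's sequence feature extractor interface, assumed here to provide
    // num_labels() == 2 and a sequence_type typedef.
    void example (
        const my_feature_extractor& fe,
        const std::vector<my_feature_extractor::sequence_type>& samples,
        const std::vector<std::vector<unsigned long> >& labels
    )
    {
        dlib::structural_sequence_labeling_trainer<my_feature_extractor> trainer(fe);

        // Every label defaults to a loss of 1, i.e. plain Hamming loss.  Make
        // misclassifying elements whose true label is 1 three times as costly
        // as misclassifying elements whose true label is 0.
        trainer.set_loss(1, 3.0);

        dlib::sequence_labeler<my_feature_extractor> labeler = trainer.train(samples, labels);
    }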
@@ -134,6 +134,40 @@ namespace dlib
            return C;
        }
double get_loss (
unsigned long label
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(label < num_labels(),
"\t void structural_sequence_labeling_trainer::get_loss()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t label: " << label
<< "\n\t num_labels(): " << num_labels()
<< "\n\t this: " << this
);
return loss_values[label];
}
void set_loss (
unsigned long label,
double value
)
{
// make sure requires clause is not broken
DLIB_ASSERT(label < num_labels() && value >= 0,
"\t void structural_sequence_labeling_trainer::set_loss()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t label: " << label
<< "\n\t num_labels(): " << num_labels()
<< "\n\t value: " << value
<< "\n\t this: " << this
);
loss_values[label] = value;
}
        const sequence_labeler<feature_extractor> train(
            const std::vector<sample_sequence_type>& x,
@@ -182,6 +216,9 @@ namespace dlib
            prob.set_epsilon(eps);
            prob.set_c(C);
            prob.set_max_cache_size(max_cache_size);
for (unsigned long i = 0; i < loss_values.size(); ++i)
prob.set_loss(i,loss_values[i]);
            solver(prob, weights);
            return sequence_labeler<feature_extractor>(weights,fe);
@@ -195,6 +232,7 @@ namespace dlib
        bool verbose;
        unsigned long num_threads;
        unsigned long max_cache_size;
std::vector<double> loss_values;
        void set_defaults ()
        {
@@ -203,6 +241,7 @@ namespace dlib
            eps = 0.1;
            num_threads = 2;
            max_cache_size = 40;
loss_values.assign(num_labels(), 1);
        }

        feature_extractor fe;
...
@@ -190,6 +190,31 @@ namespace dlib
                    better generalization.
            !*/
double get_loss (
unsigned long label
) const;
/*!
requires
- label < num_labels()
ensures
- returns the loss incurred when a sequence element with the given
label is misclassified. This value controls how much we care about
correctly classifying this type of label. Larger loss values indicate
that we care more strongly than smaller values.
!*/
void set_loss (
unsigned long label,
double value
);
/*!
requires
- label < num_labels()
- value >= 0
ensures
- #get_loss(label) == value
!*/
        const sequence_labeler<feature_extractor> train(
            const std::vector<sample_sequence_type>& x,
            const std::vector<labeled_sequence_type>& y
...
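In symbols (an editorial sketch, not from the commit): writing loss(k) for get_loss(k), y for the true label sequence and y-hat for a predicted one, the get_loss()/set_loss() semantics documented above amount to a per-label weighted Hamming loss,

    \[
        \Delta(y, \hat{y}) = \sum_{i} \mathbf{1}\!\left[\hat{y}_i \neq y_i\right] \, \mathrm{loss}(y_i),
    \]

which reduces to a plain count of mislabeled elements when every loss value is left at its default of 1.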
@@ -114,6 +114,45 @@ namespace dlib
            }
#endif
loss_values.assign(num_labels(), 1);
}
unsigned long num_labels (
) const { return fe.num_labels(); }
double get_loss (
unsigned long label
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(label < num_labels(),
"\t void structural_svm_sequence_labeling_problem::get_loss()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t label: " << label
<< "\n\t num_labels(): " << num_labels()
<< "\n\t this: " << this
);
return loss_values[label];
}
void set_loss (
unsigned long label,
double value
)
{
// make sure requires clause is not broken
DLIB_ASSERT(label < num_labels() && value >= 0,
"\t void structural_svm_sequence_labeling_problem::set_loss()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t label: " << label
<< "\n\t num_labels(): " << num_labels()
<< "\n\t value: " << value
<< "\n\t this: " << this
);
loss_values[label] = value;
        }

    private:
@@ -166,12 +205,14 @@ namespace dlib
                const sequence_type& sequence_,
                const std::vector<unsigned long>& label_,
                const feature_extractor& fe_,
                const matrix<double,0,1>& weights_,
const std::vector<double>& loss_values_
            ) :
                sequence(sequence_),
                label(label_),
                fe(fe_),
                weights(weights_),
loss_values(loss_values_)
            {
            }
@@ -194,7 +235,7 @@ namespace dlib
                double loss = 0;
                if (node_states(0) != label[node_id])
                    loss = loss_values[label[node_id]];
                return fe_helpers::dot(weights, fe, sequence, node_states, node_id) + loss;
            }
@@ -203,6 +244,7 @@ namespace dlib
            const std::vector<unsigned long>& label;
            const feature_extractor& fe;
            const matrix<double,0,1>& weights;
const std::vector<double>& loss_values;
        };

        virtual void separation_oracle (
@@ -213,13 +255,13 @@ namespace dlib
        ) const
        {
            std::vector<unsigned long> y;
            find_max_factor_graph_viterbi(map_prob(samples[idx],labels[idx],fe,current_solution,loss_values), y);
            loss = 0;
            for (unsigned long i = 0; i < y.size(); ++i)
            {
                if (y[i] != labels[idx][i])
                    loss += loss_values[labels[idx][i]];
            }
            get_joint_feature_vector(samples[idx], y, psi);
@@ -228,6 +270,7 @@ namespace dlib
        const std::vector<sequence_type>& samples;
        const std::vector<std::vector<unsigned long> >& labels;
        const feature_extractor& fe;
std::vector<double> loss_values;
        };

// ----------------------------------------------------------------------------------------
...
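The separation_oracle above performs loss-augmented inference: map_prob adds loss(y_i) to the score of any node state that disagrees with the true label, so the Viterbi search returns (an editorial sketch, with Delta as defined earlier)

    \[
        \hat{y} = \operatorname*{arg\,max}_{y} \; \langle w, \Psi(x, y) \rangle + \Delta(y^{*}, y),
    \]

the most-violated labeling required by the structural SVM's cutting plane step; the loop that follows recomputes Delta(y*, y) for the returned y.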
@@ -64,6 +64,40 @@ namespace dlib
                - This object will use num_threads threads during the optimization
                  procedure.  You should set this parameter equal to the number of
                  available processing cores on your machine.
- #num_labels() == fe.num_labels()
- for all valid i: #get_loss(i) == 1
!*/
unsigned long num_labels (
) const;
/*!
ensures
- returns the number of possible labels in this learning problem
!*/
double get_loss (
unsigned long label
) const;
/*!
requires
- label < num_labels()
ensures
- returns the loss incurred when a sequence element with the given
label is misclassified. This value controls how much we care about
correctly classifying this type of label. Larger loss values indicate
that we care more strongly than smaller values.
!*/
void set_loss (
unsigned long label,
double value
);
/*!
requires
- label < num_labels()
- value >= 0
ensures
- #get_loss(label) == value
        !*/
    };
...
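For completeness, the per-label losses can also be set on the underlying problem object directly; a sketch mirroring the train() body shown earlier (the constructor argument order and the use of dlib's oca solver are assumptions here, not shown in this diff; my_feature_extractor, fe, samples, and labels are as in the first sketch):

    // Build the structural SVM problem directly, using 2 threads (assumed
    // constructor signature: samples, labels, feature extractor, num_threads).
    dlib::structural_svm_sequence_labeling_problem<my_feature_extractor> prob(samples, labels, fe, 2);
    prob.set_c(100);
    prob.set_epsilon(0.1);
    prob.set_loss(1, 3.0);   // per-label loss, as documented above

    dlib::matrix<double,0,1> weights;
    dlib::oca solver;
    solver(prob, weights);   // mirrors the solver(prob, weights) call in the trainer

    dlib::sequence_labeler<my_feature_extractor> labeler(weights, fe);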