Commit a9c7de47 authored by Davis King's avatar Davis King

Added user settable loss to the association trainer

parent 78088d40
...@@ -143,6 +143,48 @@ namespace dlib ...@@ -143,6 +143,48 @@ namespace dlib
force_assignment = new_value; force_assignment = new_value;
} }
void set_loss_per_false_association (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss > 0,
"\t void structural_assignment_trainer::set_loss_per_false_association(loss)"
<< "\n\t Invalid inputs were given to this function "
<< "\n\t loss: " << loss
<< "\n\t this: " << this
);
loss_per_false_association = loss;
}
double get_loss_per_false_association (
) const
{
return loss_per_false_association;
}
void set_loss_per_missed_association (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss > 0,
"\t void structural_assignment_trainer::set_loss_per_missed_association(loss)"
<< "\n\t Invalid inputs were given to this function "
<< "\n\t loss: " << loss
<< "\n\t this: " << this
);
loss_per_missed_association = loss;
}
double get_loss_per_missed_association (
) const
{
return loss_per_missed_association;
}
const assignment_function<feature_extractor> train ( const assignment_function<feature_extractor> train (
const std::vector<sample_type>& samples, const std::vector<sample_type>& samples,
const std::vector<label_type>& labels const std::vector<label_type>& labels
...@@ -173,7 +215,8 @@ namespace dlib ...@@ -173,7 +215,8 @@ namespace dlib
structural_svm_assignment_problem<feature_extractor> prob(samples,labels, fe, force_assignment, num_threads); structural_svm_assignment_problem<feature_extractor> prob(samples,labels, fe, force_assignment, num_threads,
loss_per_false_association, loss_per_missed_association);
if (verbose) if (verbose)
prob.be_verbose(); prob.be_verbose();
...@@ -204,6 +247,8 @@ namespace dlib ...@@ -204,6 +247,8 @@ namespace dlib
bool verbose; bool verbose;
unsigned long num_threads; unsigned long num_threads;
unsigned long max_cache_size; unsigned long max_cache_size;
double loss_per_false_association;
double loss_per_missed_association;
void set_defaults () void set_defaults ()
{ {
...@@ -213,6 +258,8 @@ namespace dlib ...@@ -213,6 +258,8 @@ namespace dlib
eps = 0.01; eps = 0.01;
num_threads = 2; num_threads = 2;
max_cache_size = 5; max_cache_size = 5;
loss_per_false_association = 1;
loss_per_missed_association = 1;
} }
feature_extractor fe; feature_extractor fe;
......
...@@ -52,6 +52,8 @@ namespace dlib ...@@ -52,6 +52,8 @@ namespace dlib
- #get_max_cache_size() == 5 - #get_max_cache_size() == 5
- #get_feature_extractor() == a default initialized feature_extractor - #get_feature_extractor() == a default initialized feature_extractor
- #forces_assignment() == false - #forces_assignment() == false
- #get_loss_per_false_association() == 1
- #get_loss_per_missed_association() == 1
!*/ !*/
explicit structural_assignment_trainer ( explicit structural_assignment_trainer (
...@@ -66,6 +68,8 @@ namespace dlib ...@@ -66,6 +68,8 @@ namespace dlib
- #get_max_cache_size() == 40 - #get_max_cache_size() == 40
- #get_feature_extractor() == fe - #get_feature_extractor() == fe
- #forces_assignment() == false - #forces_assignment() == false
- #get_loss_per_false_association() == 1
- #get_loss_per_missed_association() == 1
!*/ !*/
const feature_extractor& get_feature_extractor ( const feature_extractor& get_feature_extractor (
...@@ -132,6 +136,46 @@ namespace dlib ...@@ -132,6 +136,46 @@ namespace dlib
of 0 means caching is not used at all. of 0 means caching is not used at all.
!*/ !*/
void set_loss_per_false_association (
double loss
);
/*!
requires
- loss > 0
ensures
- #get_loss_per_false_association() == loss
!*/
double get_loss_per_false_association (
) const;
/*!
ensures
- returns the amount of loss experienced for associating two objects
together that shouldn't be associated. If you care more about avoiding
accidental associations than ensuring all possible associations are
identified then then you can increase this value.
!*/
void set_loss_per_missed_association (
double loss
);
/*!
requires
- loss > 0
ensures
- #get_loss_per_missed_association() == loss
!*/
double get_loss_per_missed_association (
) const;
/*!
ensures
- returns the amount of loss experienced for failing to associate two
objects that are supposed to be associated. If you care more about
getting all the associations than avoiding accidentally associating
objects that shouldn't be associated then you can increase this value.
!*/
void be_verbose ( void be_verbose (
); );
/*! /*!
......
...@@ -63,16 +63,27 @@ namespace dlib ...@@ -63,16 +63,27 @@ namespace dlib
const std::vector<label_type>& labels_, const std::vector<label_type>& labels_,
const feature_extractor& fe_, const feature_extractor& fe_,
bool force_assignment_, bool force_assignment_,
unsigned long num_threads = 2 unsigned long num_threads,
const double loss_per_false_association_,
const double loss_per_missed_association_
) : ) :
structural_svm_problem_threaded<matrix_type,feature_vector_type>(num_threads), structural_svm_problem_threaded<matrix_type,feature_vector_type>(num_threads),
samples(samples_), samples(samples_),
labels(labels_), labels(labels_),
fe(fe_), fe(fe_),
force_assignment(force_assignment_) force_assignment(force_assignment_),
loss_per_false_association(loss_per_false_association_),
loss_per_missed_association(loss_per_missed_association_)
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
#ifdef ENABLE_ASSERTS #ifdef ENABLE_ASSERTS
DLIB_ASSERT(loss_per_false_association > 0 && loss_per_missed_association > 0,
"\t structural_svm_assignment_problem::structural_svm_assignment_problem()"
<< "\n\t invalid inputs were given to this function"
<< "\n\t loss_per_false_association: " << loss_per_false_association
<< "\n\t loss_per_missed_association: " << loss_per_missed_association
<< "\n\t this: " << this
);
if (force_assignment) if (force_assignment)
{ {
DLIB_ASSERT(is_forced_assignment_problem(samples, labels), DLIB_ASSERT(is_forced_assignment_problem(samples, labels),
...@@ -193,8 +204,6 @@ namespace dlib ...@@ -193,8 +204,6 @@ namespace dlib
} }
cost.set_size(size, size); cost.set_size(size, size);
const double loss_for_false_association = 1;
const double loss_for_missed_association = 1;
typename feature_extractor::feature_vector_type feats; typename feature_extractor::feature_vector_type feats;
// now fill out the cost assignment matrix // now fill out the cost assignment matrix
...@@ -213,7 +222,7 @@ namespace dlib ...@@ -213,7 +222,7 @@ namespace dlib
// add in the loss since this corresponds to an incorrect prediction. // add in the loss since this corresponds to an incorrect prediction.
if (c != labels[idx][r]) if (c != labels[idx][r])
{ {
cost(r,c) += loss_for_false_association; cost(r,c) += loss_per_false_association;
} }
} }
else else
...@@ -221,7 +230,7 @@ namespace dlib ...@@ -221,7 +230,7 @@ namespace dlib
if (labels[idx][r] == -1) if (labels[idx][r] == -1)
cost(r,c) = 0; cost(r,c) = 0;
else else
cost(r,c) = loss_for_missed_association; cost(r,c) = loss_per_missed_association;
} }
} }
...@@ -254,9 +263,9 @@ namespace dlib ...@@ -254,9 +263,9 @@ namespace dlib
if (assignment[i] != labels[idx][i]) if (assignment[i] != labels[idx][i])
{ {
if (assignment[i] == -1) if (assignment[i] == -1)
loss += loss_for_missed_association; loss += loss_per_missed_association;
else else
loss += loss_for_false_association; loss += loss_per_false_association;
} }
} }
...@@ -267,6 +276,8 @@ namespace dlib ...@@ -267,6 +276,8 @@ namespace dlib
const std::vector<label_type>& labels; const std::vector<label_type>& labels;
const feature_extractor& fe; const feature_extractor& fe;
bool force_assignment; bool force_assignment;
const double loss_per_false_association;
const double loss_per_missed_association;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -45,10 +45,14 @@ namespace dlib ...@@ -45,10 +45,14 @@ namespace dlib
const std::vector<label_type>& labels, const std::vector<label_type>& labels,
const feature_extractor& fe, const feature_extractor& fe,
bool force_assignment, bool force_assignment,
unsigned long num_threads = 2 unsigned long num_threads,
const double loss_per_false_association,
const double loss_per_missed_association
); );
/*! /*!
requires requires
- loss_per_false_association > 0
- loss_per_missed_association > 0
- is_assignment_problem(samples,labels) == true - is_assignment_problem(samples,labels) == true
- if (force_assignment) then - if (force_assignment) then
- is_forced_assignment_problem(samples,labels) == true - is_forced_assignment_problem(samples,labels) == true
...@@ -63,6 +67,12 @@ namespace dlib ...@@ -63,6 +67,12 @@ namespace dlib
- This object will use num_threads threads during the optimization - This object will use num_threads threads during the optimization
procedure. You should set this parameter equal to the number of procedure. You should set this parameter equal to the number of
available processing cores on your machine. available processing cores on your machine.
- When solving the structural SVM problem, we will use
loss_per_false_association as the loss for incorrectly associating
objects that shouldn't be associated.
- When solving the structural SVM problem, we will use
loss_per_missed_association as the loss for failing to associate to
objects that are supposed to be associated with each other.
!*/ !*/
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment