Commit 2d7e320a authored by Davis King's avatar Davis King

Added user settable loss to the structural_track_association_trainer.

parent d6c10818
......@@ -91,6 +91,48 @@ namespace dlib
return max_cache_size;
}
void set_loss_per_false_association (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss > 0,
"\t void structural_track_association_trainer::set_loss_per_false_association(loss)"
<< "\n\t Invalid inputs were given to this function "
<< "\n\t loss: " << loss
<< "\n\t this: " << this
);
loss_per_false_association = loss;
}
double get_loss_per_false_association (
) const
{
return loss_per_false_association;
}
void set_loss_per_track_break (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss > 0,
"\t void structural_track_association_trainer::set_loss_per_track_break(loss)"
<< "\n\t Invalid inputs were given to this function "
<< "\n\t loss: " << loss
<< "\n\t this: " << this
);
loss_per_track_break = loss;
}
double get_loss_per_track_break (
) const
{
return loss_per_track_break;
}
void be_verbose (
)
{
......@@ -178,6 +220,8 @@ namespace dlib
trainer.set_max_cache_size(max_cache_size);
trainer.set_num_threads(num_threads);
trainer.set_oca(solver);
trainer.set_loss_per_missed_association(loss_per_track_break);
trainer.set_loss_per_false_association(loss_per_false_association);
std::vector<std::pair<std::vector<detection_type>, std::vector<track_type> > > assignment_samples;
std::vector<std::vector<long> > labels;
......@@ -338,6 +382,8 @@ namespace dlib
unsigned long num_threads;
unsigned long max_cache_size;
bool learn_nonnegative_weights;
double loss_per_track_break;
double loss_per_false_association;
void set_defaults ()
{
......@@ -347,6 +393,8 @@ namespace dlib
num_threads = 2;
max_cache_size = 5;
learn_nonnegative_weights = false;
loss_per_track_break = 1;
loss_per_false_association = 1;
}
};
......
......@@ -39,6 +39,8 @@ namespace dlib
- #get_num_threads() == 2
- #get_max_cache_size() == 5
- #learns_nonnegative_weights() == false
- #get_loss_per_track_break() == 1
- #get_loss_per_false_association() == 1
!*/
void set_num_threads (
......@@ -113,6 +115,45 @@ namespace dlib
- this object will not print anything to standard out
!*/
void set_loss_per_false_association (
double loss
);
/*!
requires
- loss > 0
ensures
- #get_loss_per_false_association() == loss
!*/
double get_loss_per_false_association (
) const;
/*!
ensures
- returns the amount of loss experienced for assigning a detection to the
wrong track. If you care more about avoiding false associations than
avoiding track breaks then you can increase this value.
!*/
void set_loss_per_track_break (
double loss
);
/*!
requires
- loss > 0
ensures
- #get_loss_per_track_break() == loss
!*/
double get_loss_per_track_break (
) const;
/*!
ensures
- returns the amount of loss experienced for incorrectly assigning a
detection to a new track instead of assigning it to its existing track.
If you care more about avoiding track breaks than avoiding things like
track swaps then you can increase this value.
!*/
void set_oca (
const oca& item
);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment