Commit aaeb52ba authored by Davis King's avatar Davis King

Updated the interface to allow the user to set different loss values for

false alarming vs getting a correct detection.
parent 1de36ea2
......@@ -27,11 +27,15 @@ namespace dlib
const feature_extractor& fe_
) : trainer(impl_ss::feature_extractor<feature_extractor>(fe_))
{
loss_per_missed_segment = 1;
loss_per_false_alarm = 1;
}
structural_sequence_segmentation_trainer (
)
{
loss_per_missed_segment = 1;
loss_per_false_alarm = 1;
}
const feature_extractor& get_feature_extractor (
......@@ -127,6 +131,63 @@ namespace dlib
return trainer.get_c();
}
void set_loss_per_missed_segment (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss >= 0,
"\t void structural_sequence_segmentation_trainer::set_loss_per_missed_segment(loss)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t loss: " << loss
<< "\n\t this: " << this
);
loss_per_missed_segment = loss;
if (feature_extractor::use_BIO_model)
{
trainer.set_loss(impl_ss::BEGIN, loss_per_missed_segment);
trainer.set_loss(impl_ss::INSIDE, loss_per_missed_segment);
}
else
{
trainer.set_loss(impl_ss::BEGIN, loss_per_missed_segment);
trainer.set_loss(impl_ss::INSIDE, loss_per_missed_segment);
trainer.set_loss(impl_ss::LAST, loss_per_missed_segment);
trainer.set_loss(impl_ss::UNIT, loss_per_missed_segment);
}
}
double get_loss_per_missed_segment (
) const
{
return loss_per_missed_segment;
}
void set_loss_per_false_alarm (
double loss
)
{
// make sure requires clause is not broken
DLIB_ASSERT(loss >= 0,
"\t void structural_sequence_segmentation_trainer::set_loss_per_false_alarm(loss)"
<< "\n\t invalid inputs were given to this function"
<< "\n\t loss: " << loss
<< "\n\t this: " << this
);
loss_per_false_alarm = loss;
trainer.set_loss(impl_ss::OUTSIDE, loss_per_false_alarm);
}
double get_loss_per_false_alarm (
) const
{
return loss_per_false_alarm;
}
const sequence_segmenter<feature_extractor> train(
const std::vector<sample_sequence_type>& x,
const std::vector<segmented_sequence_type>& y
......@@ -198,6 +259,8 @@ namespace dlib
private:
structural_sequence_labeling_trainer<impl_ss::feature_extractor<feature_extractor> > trainer;
double loss_per_missed_segment;
double loss_per_false_alarm;
};
// ----------------------------------------------------------------------------------------
......
......@@ -47,6 +47,8 @@ namespace dlib
- #get_num_threads() == 2
- #get_max_cache_size() == 40
- #get_feature_extractor() == a default initialized feature_extractor
- #get_loss_per_missed_segment() == 1
- #get_loss_per_false_alarm() == 1
!*/
explicit structural_sequence_segmentation_trainer (
......@@ -60,6 +62,8 @@ namespace dlib
- #get_num_threads() == 2
- #get_max_cache_size() == 40
- #get_feature_extractor() == fe
- #get_loss_per_missed_segment() == 1
- #get_loss_per_false_alarm() == 1
!*/
const feature_extractor& get_feature_extractor (
......@@ -178,6 +182,44 @@ namespace dlib
generalization.
!*/
void set_loss_per_missed_segment (
double loss
);
/*!
requires
- loss >= 0
ensures
- #get_loss_per_missed_segment() == loss
!*/
double get_loss_per_missed_segment (
) const;
/*!
ensures
- returns the amount of loss incurred for failing to detect a segment. The
larger the loss the more important it is to detect all the segments.
!*/
void set_loss_per_false_alarm (
double loss
);
/*!
requires
- loss >= 0
ensures
- #get_loss_per_false_alarm() == loss
!*/
double get_loss_per_false_alarm (
) const;
/*!
ensures
- returns the amount of loss incurred for outputting a false detection. The
larger the loss the more important it is to avoid outputting false
detections.
!*/
const sequence_segmenter<feature_extractor> train(
const std::vector<sample_sequence_type>& x,
const std::vector<segmented_sequence_type>& y
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment