Commit a7e55c79 authored by Davis King's avatar Davis King

Added the option to force the last weight to 1 to the assignment learning

tools.
parent 097354f7
......@@ -185,6 +185,19 @@ namespace dlib
return loss_per_missed_association;
}
bool forces_last_weight_to_1 (
) const
{
return last_weight_1;
}
void force_last_weight_to_1 (
bool should_last_weight_be_1
)
{
last_weight_1 = should_last_weight_be_1;
}
const assignment_function<feature_extractor> train (
const std::vector<sample_type>& samples,
const std::vector<label_type>& labels
......@@ -230,7 +243,10 @@ namespace dlib
// Take the min here because we want to prevent the user from accidentally
// forcing the bias term to be non-negative.
const unsigned long num_nonneg = std::min(fe.num_features(),num_nonnegative_weights(fe));
solver(prob, weights, num_nonneg);
if (last_weight_1)
solver(prob, weights, num_nonneg, fe.num_features()-1);
else
solver(prob, weights, num_nonneg);
const double bias = weights(weights.size()-1);
return assignment_function<feature_extractor>(colm(weights,0,weights.size()-1), bias,fe,force_assignment);
......@@ -249,6 +265,7 @@ namespace dlib
unsigned long max_cache_size;
double loss_per_false_association;
double loss_per_missed_association;
bool last_weight_1;
void set_defaults ()
{
......@@ -260,6 +277,7 @@ namespace dlib
max_cache_size = 5;
loss_per_false_association = 1;
loss_per_missed_association = 1;
last_weight_1 = false;
}
feature_extractor fe;
......
......@@ -54,6 +54,7 @@ namespace dlib
- #forces_assignment() == false
- #get_loss_per_false_association() == 1
- #get_loss_per_missed_association() == 1
- #forces_last_weight_to_1() == false
!*/
explicit structural_assignment_trainer (
......@@ -70,6 +71,7 @@ namespace dlib
- #forces_assignment() == false
- #get_loss_per_false_association() == 1
- #get_loss_per_missed_association() == 1
- #forces_last_weight_to_1() == false
!*/
const feature_extractor& get_feature_extractor (
......@@ -244,6 +246,26 @@ namespace dlib
assignment_functions generated by this object.
!*/
bool forces_last_weight_to_1 (
) const;
/*!
ensures
- returns true if this trainer has the constraint that the last weight in
the learned parameter vector must be 1. This is the weight corresponding
to the feature in the training vectors with the highest dimension.
- Forcing the last weight to 1 also disables the bias and therefore the
get_bias() field of the learned assignment_function will be 0 when
forces_last_weight_to_1() == true.
!*/
void force_last_weight_to_1 (
bool should_last_weight_be_1
);
/*!
ensures
- #forces_last_weight_to_1() == should_last_weight_be_1
!*/
const assignment_function<feature_extractor> train (
const std::vector<sample_type>& samples,
const std::vector<label_type>& labels
......@@ -262,6 +284,9 @@ namespace dlib
new_sample.first match up with the elements of new_sample.second.
- F.forces_assignment() == forces_assignment()
- F.get_feature_extractor() == get_feature_extractor()
- if (forces_last_weight_to_1()) then
- F.get_bias() == 0
- F.get_weights()(F.get_weights().size()-1) == 1
!*/
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment