Commit bcf6bd9a authored by Davis King's avatar Davis King

Cleaned up loss_metric_ code a little

parent 34cce6da
......@@ -872,6 +872,17 @@ namespace dlib
typedef unsigned long training_label_type;
typedef matrix<float,0,1> output_label_type;
loss_metric_() = default;
loss_metric_(
float margin_,
float dist_thresh_
) : margin(margin_), dist_thresh(dist_thresh_)
{
DLIB_CASSERT(margin_ > 0);
DLIB_CASSERT(dist_thresh_ > 0);
}
template <
typename SUB_TYPE,
typename label_iterator
......@@ -900,8 +911,9 @@ namespace dlib
}
}
float get_margin() const { return 0.1; }
float get_distance_threshold() const { return 0.75; }
float get_margin() const { return margin; }
float get_distance_threshold() const { return dist_thresh; }
template <
typename const_label_iterator,
......@@ -927,8 +939,6 @@ namespace dlib
grad.nc() == 1);
const float margin = get_margin();
const float dist_thresh = get_distance_threshold();
temp.set_size(output_tensor.num_samples(), output_tensor.num_samples());
grad_mul.copy_size(temp);
......@@ -1043,31 +1053,52 @@ namespace dlib
return loss;
}
friend void serialize(const loss_metric_& , std::ostream& out)
friend void serialize(const loss_metric_& item, std::ostream& out)
{
serialize("loss_metric_", out);
serialize("loss_metric_2", out);
serialize(item.margin, out);
serialize(item.dist_thresh, out);
}
friend void deserialize(loss_metric_& , std::istream& in)
friend void deserialize(loss_metric_& item, std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "loss_metric_")
if (version == "loss_metric_")
{
// These values used to be hard coded, so for this version of the metric
// learning loss we just use these values.
item.margin = 0.1;
item.dist_thresh = 0.75;
return;
}
else if (version == "loss_metric_2")
{
deserialize(item.margin, in);
deserialize(item.dist_thresh, in);
}
else
{
throw serialization_error("Unexpected version found while deserializing dlib::loss_metric_. Instead found " + version);
}
}
friend std::ostream& operator<<(std::ostream& out, const loss_metric_& )
friend std::ostream& operator<<(std::ostream& out, const loss_metric_& item )
{
out << "loss_metric";
out << "loss_metric (margin="<<item.margin<<", distance_threshold="<<item.dist_thresh<<")";
return out;
}
friend void to_xml(const loss_metric_& /*item*/, std::ostream& out)
friend void to_xml(const loss_metric_& item, std::ostream& out)
{
out << "<loss_metric/>";
out << "<loss_metric margin='"<<item.margin<<"' distance_threshold='"<<item.dist_thresh<<"'/>";
}
private:
float margin = 0.04;
float dist_thresh = 0.6;
// These variables are only here to avoid being reallocated over and over in
// compute_loss_value_and_gradient()
mutable resizable_tensor temp, grad_mul;
......
......@@ -559,6 +559,27 @@ namespace dlib
typedef unsigned long training_label_type;
typedef matrix<float,0,1> output_label_type;
loss_metric_(
);
/*!
ensures
- #get_margin() == 0.04
- #get_distance_threshold() == 0.6
!*/
loss_metric_(
float margin,
float dist_thresh
);
/*!
requires
- margin > 0
- dist_thresh > 0
ensures
- #get_margin() == margin
- #get_distance_threshold() == dist_thresh
!*/
template <
typename SUB_TYPE,
typename label_iterator
......@@ -581,14 +602,14 @@ namespace dlib
given to this function, one for each sample in the input_tensor.
!*/
float get_margin() const { return 0.1; }
float get_margin() const;
/*!
ensures
- returns the margin value used by the loss function. See the discussion
in WHAT THIS OBJECT REPRESENTS for details.
!*/
float get_distance_threshold() const { return 0.75; }
float get_distance_threshold() const;
/*!
ensures
- returns the distance threshold value used by the loss function. See the discussion
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment