Commit 124e0ff4 authored by Davis King's avatar Davis King

Removed loss_metric_hardish_

parent 4219185d
......@@ -1075,223 +1075,6 @@ namespace dlib
template <typename SUBNET>
using loss_metric = add_loss_layer<loss_metric_, SUBNET>;
// ----------------------------------------------------------------------------------------
class loss_metric_hardish_
{
public:
typedef unsigned long training_label_type;
typedef matrix<float,0,1> output_label_type;
template <
typename SUB_TYPE,
typename label_iterator
>
void to_label (
const tensor& input_tensor,
const SUB_TYPE& sub,
label_iterator iter
) const
{
const tensor& output_tensor = sub.get_output();
DLIB_CASSERT(sub.sample_expansion_factor() == 1);
DLIB_CASSERT(input_tensor.num_samples() != 0);
DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
DLIB_CASSERT(output_tensor.nr() == 1 &&
output_tensor.nc() == 1);
const float* p = output_tensor.host();
for (long i = 0; i < output_tensor.num_samples(); ++i)
{
*iter = mat(p,output_tensor.k(),1);
++iter;
p += output_tensor.k();
}
}
template <
typename const_label_iterator,
typename SUBNET
>
double compute_loss_value_and_gradient (
const tensor& input_tensor,
const_label_iterator truth,
SUBNET& sub
) const
{
const tensor& output_tensor = sub.get_output();
tensor& grad = sub.get_gradient_input();
DLIB_CASSERT(sub.sample_expansion_factor() == 1);
DLIB_CASSERT(input_tensor.num_samples() != 0);
DLIB_CASSERT(input_tensor.num_samples()%sub.sample_expansion_factor() == 0);
DLIB_CASSERT(input_tensor.num_samples() == grad.num_samples());
DLIB_CASSERT(input_tensor.num_samples() == output_tensor.num_samples());
DLIB_CASSERT(output_tensor.nr() == 1 &&
output_tensor.nc() == 1);
DLIB_CASSERT(grad.nr() == 1 &&
grad.nc() == 1);
const float margin = 0.1;
const float dist_thresh = 0.75;
temp.set_size(output_tensor.num_samples(), output_tensor.num_samples());
grad_mul.copy_size(temp);
tt::gemm(0, temp, 1, output_tensor, false, output_tensor, true);
const float* d = temp.host();
double loss = 0;
double num_pos_samps = 0.0001;
double num_active_neg_samps = 0.0001;
for (long r = 0; r < temp.num_samples(); ++r)
{
auto xx = d[r*temp.num_samples() + r];
const auto x_label = *(truth + r);
for (long c = r+1; c < temp.num_samples(); ++c)
{
const auto y_label = *(truth + c);
if (x_label == y_label)
{
++num_pos_samps;
}
else
{
// Figure out what distance threshold, when applied to the negative pairs,
// causes there to be an equal number of positive and negative pairs.
auto yy = d[c*temp.num_samples() + c];
auto xy = d[r*temp.num_samples() + c];
// compute the distance between x and y samples.
auto d2 = xx + yy - 2*xy;
if (d2 < 0)
d2 = 0;
else
d2 = std::sqrt(d2);
if (d2 < dist_thresh+margin)
++num_active_neg_samps;
}
}
}
// The whole objective function is multiplied by this to scale the loss
// relative to the number of things in the mini-batch.
const double scale = 0.5/num_pos_samps;
DLIB_CASSERT(num_pos_samps>=1, "Make sure each mini-batch contains both positive pairs and negative pairs");
// We will pick which negative pairs to include in the objective by selecting
// randomly among the pairs that violate the margin. This means we should pick
// this fraction of the margin violators if we want the pos/neg ratio to be
// balanced.
const double neg_select_probability = num_pos_samps/num_active_neg_samps;
// loop over all the pairs of training samples and compute the loss and
// gradients. Note that we only use the hardest negative pairs and that in
// particular we pick the number of negative pairs equal to the number of
// positive pairs so everything is balanced.
float* gm = grad_mul.host();
for (long r = 0; r < temp.num_samples(); ++r)
{
gm[r*temp.num_samples() + r] = 0;
const auto x_label = *(truth + r);
auto xx = d[r*temp.num_samples() + r];
for (long c = 0; c < temp.num_samples(); ++c)
{
if (r==c)
continue;
const auto y_label = *(truth + c);
auto yy = d[c*temp.num_samples() + c];
auto xy = d[r*temp.num_samples() + c];
// compute the distance between x and y samples.
auto d2 = xx + yy - 2*xy;
if (d2 <= 0)
d2 = 0;
else
d2 = std::sqrt(d2);
if (x_label == y_label)
{
// Things with the same label should have distances < dist_thresh between
// them. If not then we experience non-zero loss.
if (d2 < dist_thresh-margin)
{
gm[r*temp.num_samples() + c] = 0;
}
else
{
loss += scale*(d2 - (dist_thresh-margin));
gm[r*temp.num_samples() + r] += scale/d2;
gm[r*temp.num_samples() + c] = -scale/d2;
}
}
else
{
// Things with different labels should have distances > dist_thresh between
// them. If not then we experience non-zero loss.
if (d2 > dist_thresh+margin || rnd.get_random_double() > neg_select_probability)
{
gm[r*temp.num_samples() + c] = 0;
}
else
{
loss += scale*((dist_thresh+margin) - d2);
// don't divide by zero (or a really small number)
d2 = std::max(d2, 0.001f);
gm[r*temp.num_samples() + r] -= scale/d2;
gm[r*temp.num_samples() + c] = scale/d2;
}
}
}
}
tt::gemm(0, grad, 1, grad_mul, false, output_tensor, false);
return loss;
}
friend void serialize(const loss_metric_hardish_& , std::ostream& out)
{
serialize("loss_metric_hardish_", out);
}
friend void deserialize(loss_metric_hardish_& , std::istream& in)
{
std::string version;
deserialize(version, in);
if (version != "loss_metric_hardish_")
throw serialization_error("Unexpected version found while deserializing dlib::loss_metric_hardish_. Instead found: " + version);
}
friend std::ostream& operator<<(std::ostream& out, const loss_metric_hardish_& )
{
out << "loss_metric_hardish";
return out;
}
friend void to_xml(const loss_metric_hardish_& /*item*/, std::ostream& out)
{
out << "<loss_metric_hardish/>";
}
private:
// These variables are only here to avoid being reallocated over and over in
// compute_loss_value_and_gradient()
mutable resizable_tensor temp, grad_mul;
mutable dlib::rand rnd;
};
template <typename SUBNET>
using loss_metric_hardish = add_loss_layer<loss_metric_hardish_, SUBNET>;
// ----------------------------------------------------------------------------------------
class loss_mean_squared_
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment