Commit f67cc883 authored by Davis King's avatar Davis King

Made the structural svm solver use its cache elements to try and mitigate the

errors made by an approximate separation oracle.  In particular, the solver
will now check the output of the separation oracle against the cache and if the
cache gives a better value, even when we would otherwise not use the cache, the
cache value is used.  Similarly, we can output the truth psi vector to avoid
outputting a psi with a negative risk.  All this stuff only happens when the
cache is enabled, if its disabled then the outputs of the separation oracle are
used without any kind of modification.
parent 0a288456
......@@ -77,15 +77,12 @@ namespace dlib
// Don't waste time computing this if the cache isn't going to be used.
const scalar_type dot_true_psi = cache_enabled ? dot(true_psi, current_solution) : 0;
if (!skip_cache && cache_enabled)
scalar_type best_risk = -std::numeric_limits<scalar_type>::infinity();
unsigned long best_idx = 0;
long max_lru_count = 0;
if (cache_enabled)
{
scalar_type best_risk = -std::numeric_limits<scalar_type>::infinity();
unsigned long best_idx = 0;
// figure out which element in the cache is the best (i.e. has the biggest risk)
long max_lru_count = 0;
for (unsigned long i = 0; i < loss.size(); ++i)
{
const scalar_type risk = loss[i] + dot(psi[i], current_solution) - dot_true_psi;
......@@ -99,15 +96,19 @@ namespace dlib
max_lru_count = lru_count[i];
}
// Check if the best psi vector in the cache is still good enough to use as
// a proxy for the true separation oracle. If the risk value has dropped
// by enough to get into the stopping condition then the best psi isn't
// good enough.
if (best_risk + saved_current_risk_gap > last_true_risk_computed)
if (!skip_cache)
{
out_psi = psi[best_idx];
lru_count[best_idx] = max_lru_count + 1;
return;
// Check if the best psi vector in the cache is still good enough to use as
// a proxy for the true separation oracle. If the risk value has dropped
// by enough to get into the stopping condition then the best psi isn't
// good enough.
if (best_risk + saved_current_risk_gap > last_true_risk_computed &&
best_risk >= 0)
{
out_psi = psi[best_idx];
lru_count[best_idx] = max_lru_count + 1;
return;
}
}
}
......@@ -121,8 +122,25 @@ namespace dlib
last_true_risk_computed = out_loss + dot(out_psi, current_solution) - dot_true_psi;
// If the separation oracle is only solved approximately then the result might
// not be as good as just selecting true_psi as the output. So here we check
// if that is the case.
if (last_true_risk_computed < 0 && best_risk < 0)
{
out_psi = true_psi;
out_loss = 0;
}
// Alternatively, an approximate separation oracle might not do as well as just
// selecting from the cache. So if that is the case when just take the best
// element from the cache.
else if (last_true_risk_computed < best_risk)
{
out_psi = psi[best_idx];
out_loss = loss[best_idx];
lru_count[best_idx] = max_lru_count + 1;
}
// if the cache is full
if (loss.size() >= prob->get_max_cache_size())
else if (loss.size() >= prob->get_max_cache_size())
{
// find least recently used cache entry for idx-th sample
const long i = index_of_min(mat(lru_count));
......@@ -138,6 +156,8 @@ namespace dlib
}
else
{
// In this case we just append the new psi into the cache.
loss.push_back(out_loss);
psi.push_back(out_psi);
long max_use = 1;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment