Commit ae4c1bef authored by Davis King's avatar Davis King

Improved the feature vector caching in the structural_svm_problem object.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404246
parent 50832ee5
......@@ -35,6 +35,7 @@ namespace dlib
- if (cache.size() != 0) then
- cache.size() == get_num_samples()
- true_psis.size() == get_num_samples()
- for all i: cache[i] == the cached results of calls to separation_oracle()
for the i-th sample.
!*/
......@@ -44,6 +45,7 @@ namespace dlib
structural_svm_problem (
) :
cur_risk_lower_bound(0),
eps(0.001),
verbose(false),
skip_cache(true),
......@@ -140,8 +142,8 @@ namespace dlib
virtual bool optimization_status (
scalar_type current_objective_value,
scalar_type current_error_gap,
scalar_type ,
scalar_type ,
scalar_type current_risk_value,
scalar_type current_risk_gap,
unsigned long num_cutting_planes,
unsigned long num_iterations
) const
......@@ -149,14 +151,16 @@ namespace dlib
if (verbose)
{
using namespace std;
cout << "svm objective: " << current_objective_value << endl;
cout << "gap: " << current_error_gap << endl;
cout << "num planes: " << num_cutting_planes << endl;
cout << "iter: " << num_iterations << endl;
cout << "objective: " << current_objective_value << endl;
cout << "objective gap: " << current_error_gap << endl;
cout << "risk: " << current_risk_value << endl;
cout << "risk gap: " << current_risk_gap << endl;
cout << "num planes: " << num_cutting_planes << endl;
cout << "iter: " << num_iterations << endl;
cout << endl;
}
cur_gap = current_error_gap/get_c();
cur_risk_lower_bound = std::max<scalar_type>(current_risk_value - current_risk_gap, 0);
bool should_stop = false;
......@@ -196,10 +200,25 @@ namespace dlib
{
psi_true.set_size(w.size(),1);
psi_true = 0;
for (unsigned long i = 0; i < num; ++i)
// If the cache is enabled then populate the true_psis array. But
// in either case sum them all up and store the result in psi_true.
if (max_cache_size != 0)
{
get_truth_joint_feature_vector(i, ftemp);
sparse_vector::subtract_from(psi_true, ftemp);
true_psis.resize(num);
for (unsigned long i = 0; i < num; ++i)
{
get_truth_joint_feature_vector(i, true_psis[i]);
sparse_vector::subtract_from(psi_true, true_psis[i]);
}
}
else
{
for (unsigned long i = 0; i < num; ++i)
{
get_truth_joint_feature_vector(i, ftemp);
sparse_vector::subtract_from(psi_true, ftemp);
}
}
}
......@@ -215,8 +234,8 @@ namespace dlib
subgradient /= num;
total_loss /= num;
risk = total_loss + dot(subgradient,w);
cur_risk = risk;
// Include a sanity check that the risk is always non-negative.
risk = std::max<scalar_type>(total_loss + dot(subgradient,w), 0);
}
void separation_oracle_cached (
......@@ -232,28 +251,35 @@ namespace dlib
if (!skip_cache && max_cache_size != 0)
{
scalar_type best_val = -std::numeric_limits<scalar_type>::infinity();
scalar_type best_risk = -std::numeric_limits<scalar_type>::infinity();
unsigned long best_idx = 0;
cache_record& rec = cache[idx];
// figure out which element in the cache is the best
using sparse_vector::dot;
using dlib::dot;
const scalar_type dot_true_psi = dot(true_psis[idx], current_solution);
// figure out which element in the cache is the best (i.e. has the biggest risk)
long max_lru_count = 0;
for (unsigned long i = 0; i < rec.loss.size(); ++i)
{
using sparse_vector::dot;
const scalar_type temp = rec.loss[i] + dot(rec.psi[i], current_solution);
if (temp > best_val)
const scalar_type risk = rec.loss[i] + dot(rec.psi[i], current_solution) - dot_true_psi;
if (risk > best_risk)
{
best_val = temp;
best_risk = risk;
loss = rec.loss[i];
best_idx = i;
}
if (rec.lru_count[i] > max_lru_count)
max_lru_count = rec.lru_count[i];
}
if (best_val > cur_risk-cur_gap)
if (best_risk - cur_risk_lower_bound > eps)
{
psi = rec.psi[best_idx];
rec.lru_count[best_idx] += 1;
rec.lru_count[best_idx] = max_lru_count + 1;
return;
}
}
......@@ -267,6 +293,9 @@ namespace dlib
{
cache[idx].loss.push_back(loss);
cache[idx].psi.push_back(psi);
long max_use = 1;
if (cache[idx].lru_count.size() != 0)
max_use = max(vector_to_matrix(cache[idx].lru_count)) + 1;
cache[idx].lru_count.push_back(cache[idx].lru_count.size());
}
else
......@@ -294,12 +323,13 @@ namespace dlib
};
mutable scalar_type cur_risk;
mutable scalar_type cur_gap;
mutable scalar_type cur_risk_lower_bound;
mutable matrix_type psi_true;
scalar_type eps;
mutable bool verbose;
mutable std::vector<feature_vector_type> true_psis;
mutable std::vector<cache_record> cache;
mutable bool skip_cache;
unsigned long max_cache_size;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment