Commit ae4c1bef authored by Davis King

Improved the feature vector caching in the structural_svm_problem object.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404246
parent 50832ee5
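For readers skimming the diff below: the commit caches each training sample's truth feature vector (`true_psis`) and changes the separation-oracle cache lookup so that cached vectors are ranked by their risk, `loss + <psi, w> - <psi_true, w>`, instead of by `loss + <psi, w>` alone, and a cached vector is reused only when that risk beats the current lower bound on the risk by more than `eps`. The following is a minimal, self-contained sketch of that lookup rule using plain dense vectors; the names `sample_cache` and `lookup` are illustrative only and are not dlib's API, which works on sparse feature vectors inside `structural_svm_problem`.

```cpp
// Minimal sketch of the risk-based cache lookup introduced by this commit.
// Dense std::vector<double> stands in for dlib's sparse feature vectors, and
// sample_cache/lookup are hypothetical names, not part of dlib.
#include <algorithm>
#include <cstddef>
#include <limits>
#include <numeric>
#include <vector>

struct sample_cache
{
    std::vector<double>              true_psi;   // cached truth feature vector for this sample
    std::vector<std::vector<double>> psi;        // previously returned separation-oracle vectors
    std::vector<double>              loss;       // loss associated with each cached psi
    std::vector<long>                lru_count;  // larger value == used more recently
};

static double dot(const std::vector<double>& a, const std::vector<double>& b)
{
    return std::inner_product(a.begin(), a.end(), b.begin(), 0.0);
}

// Returns true if a cached vector is good enough to reuse (filling psi_out and
// loss_out); otherwise the caller should run the real separation oracle and
// append its output to the cache.
bool lookup(
    sample_cache& rec,
    const std::vector<double>& w,   // current solution
    double cur_risk_lower_bound,    // max(current_risk_value - current_risk_gap, 0)
    double eps,
    std::vector<double>& psi_out,
    double& loss_out
)
{
    double best_risk = -std::numeric_limits<double>::infinity();
    std::size_t best_idx = 0;
    long max_lru_count = 0;

    const double dot_true_psi = dot(rec.true_psi, w);
    for (std::size_t i = 0; i < rec.loss.size(); ++i)
    {
        // Rank cached vectors by their risk, not just by loss + <psi, w>.
        const double risk = rec.loss[i] + dot(rec.psi[i], w) - dot_true_psi;
        if (risk > best_risk)
        {
            best_risk = risk;
            best_idx  = i;
        }
        max_lru_count = std::max(max_lru_count, rec.lru_count[i]);
    }

    // Reuse the best cached vector only if its risk is meaningfully above the
    // current lower bound on the risk (an empty cache never passes this test
    // because best_risk stays at -infinity).
    if (best_risk - cur_risk_lower_bound > eps)
    {
        psi_out  = rec.psi[best_idx];
        loss_out = rec.loss[best_idx];
        rec.lru_count[best_idx] = max_lru_count + 1;  // mark as most recently used
        return true;
    }
    return false;
}
```

The `lru_count` bookkeeping mirrors the diff: a reused entry's count is bumped above the current maximum, presumably so the least recently used entry can be identified for replacement once the cache reaches `max_cache_size` (that eviction logic sits outside the hunks shown here).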
@@ -35,6 +35,7 @@ namespace dlib
             - if (cache.size() != 0) then
                 - cache.size() == get_num_samples()
+                - true_psis.size() == get_num_samples()
                 - for all i: cache[i] == the cached results of calls to separation_oracle()
                   for the i-th sample.
     !*/
@@ -44,6 +45,7 @@ namespace dlib
         structural_svm_problem (
         ) :
+            cur_risk_lower_bound(0),
             eps(0.001),
             verbose(false),
             skip_cache(true),
@@ -140,8 +142,8 @@ namespace dlib
         virtual bool optimization_status (
             scalar_type current_objective_value,
             scalar_type current_error_gap,
-            scalar_type ,
-            scalar_type ,
+            scalar_type current_risk_value,
+            scalar_type current_risk_gap,
             unsigned long num_cutting_planes,
             unsigned long num_iterations
         ) const
@@ -149,14 +151,16 @@ namespace dlib
             if (verbose)
             {
                 using namespace std;
-                cout << "svm objective: " << current_objective_value << endl;
-                cout << "gap: " << current_error_gap << endl;
-                cout << "num planes: " << num_cutting_planes << endl;
-                cout << "iter: " << num_iterations << endl;
+                cout << "objective: " << current_objective_value << endl;
+                cout << "objective gap: " << current_error_gap << endl;
+                cout << "risk: " << current_risk_value << endl;
+                cout << "risk gap: " << current_risk_gap << endl;
+                cout << "num planes: " << num_cutting_planes << endl;
+                cout << "iter: " << num_iterations << endl;
                 cout << endl;
             }

-            cur_gap = current_error_gap/get_c();
+            cur_risk_lower_bound = std::max<scalar_type>(current_risk_value - current_risk_gap, 0);

             bool should_stop = false;
@@ -196,10 +200,25 @@ namespace dlib
         {
             psi_true.set_size(w.size(),1);
             psi_true = 0;
-            for (unsigned long i = 0; i < num; ++i)
+
+            // If the cache is enabled then populate the true_psis array.  But
+            // in either case sum them all up and store the result in psi_true.
+            if (max_cache_size != 0)
             {
-                get_truth_joint_feature_vector(i, ftemp);
-                sparse_vector::subtract_from(psi_true, ftemp);
+                true_psis.resize(num);
+                for (unsigned long i = 0; i < num; ++i)
+                {
+                    get_truth_joint_feature_vector(i, true_psis[i]);
+                    sparse_vector::subtract_from(psi_true, true_psis[i]);
+                }
+            }
+            else
+            {
+                for (unsigned long i = 0; i < num; ++i)
+                {
+                    get_truth_joint_feature_vector(i, ftemp);
+                    sparse_vector::subtract_from(psi_true, ftemp);
+                }
             }
         }
@@ -215,8 +234,8 @@ namespace dlib
             subgradient /= num;
             total_loss /= num;

-            risk = total_loss + dot(subgradient,w);
-            cur_risk = risk;
+            // Include a sanity check that the risk is always non-negative.
+            risk = std::max<scalar_type>(total_loss + dot(subgradient,w), 0);
         }

         void separation_oracle_cached (
@@ -232,28 +251,35 @@ namespace dlib
            if (!skip_cache && max_cache_size != 0)
            {
-               scalar_type best_val = -std::numeric_limits<scalar_type>::infinity();
+               scalar_type best_risk = -std::numeric_limits<scalar_type>::infinity();
                unsigned long best_idx = 0;
                cache_record& rec = cache[idx];

-               // figure out which element in the cache is the best
+               using sparse_vector::dot;
+               using dlib::dot;
+               const scalar_type dot_true_psi = dot(true_psis[idx], current_solution);
+
+               // figure out which element in the cache is the best (i.e. has the biggest risk)
+               long max_lru_count = 0;
                for (unsigned long i = 0; i < rec.loss.size(); ++i)
                {
-                   using sparse_vector::dot;
-                   const scalar_type temp = rec.loss[i] + dot(rec.psi[i], current_solution);
-                   if (temp > best_val)
+                   const scalar_type risk = rec.loss[i] + dot(rec.psi[i], current_solution) - dot_true_psi;
+                   if (risk > best_risk)
                    {
-                       best_val = temp;
+                       best_risk = risk;
                        loss = rec.loss[i];
                        best_idx = i;
                    }
+
+                   if (rec.lru_count[i] > max_lru_count)
+                       max_lru_count = rec.lru_count[i];
                }

-               if (best_val > cur_risk-cur_gap)
+               if (best_risk - cur_risk_lower_bound > eps)
                {
                    psi = rec.psi[best_idx];
-                   rec.lru_count[best_idx] += 1;
+                   rec.lru_count[best_idx] = max_lru_count + 1;
                    return;
                }
            }
@@ -267,6 +293,9 @@ namespace dlib
            {
                cache[idx].loss.push_back(loss);
                cache[idx].psi.push_back(psi);
+               long max_use = 1;
+               if (cache[idx].lru_count.size() != 0)
+                   max_use = max(vector_to_matrix(cache[idx].lru_count)) + 1;
                cache[idx].lru_count.push_back(cache[idx].lru_count.size());
            }
            else
@@ -294,12 +323,13 @@ namespace dlib
        };

-       mutable scalar_type cur_risk;
-       mutable scalar_type cur_gap;
+       mutable scalar_type cur_risk_lower_bound;
        mutable matrix_type psi_true;
        scalar_type eps;
        mutable bool verbose;

+       mutable std::vector<feature_vector_type> true_psis;
        mutable std::vector<cache_record> cache;
        mutable bool skip_cache;
        unsigned long max_cache_size;
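Design note on the stopping/caching interaction (inferred only from the hunks above): previously a cached vector was reused when best_val > cur_risk - cur_gap, with cur_gap derived from the objective gap divided by C. The new code instead keeps an explicit lower bound on the risk, cur_risk_lower_bound = max(current_risk_value - current_risk_gap, 0), and reuses a cached vector only when its risk exceeds that bound by more than eps. For example, with a hypothetical current_risk_value = 0.8 and current_risk_gap = 0.05 the bound is 0.75, so with the default eps = 0.001 a cached vector is reused only if its risk is above 0.751; otherwise the real separation_oracle() is called. This keeps the cache test in the same units as the risk itself and clamps the bound at zero, matching the new non-negativity check applied to the computed risk.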