Commit 95aacdfd authored by Davis King

Improved the way the feature vector cache is used within the structural svm
solver.  This makes some things, such as the structural_object_detection_trainer,
significantly faster.
parent 06d1331c
...@@ -24,7 +24,7 @@ namespace dlib ...@@ -24,7 +24,7 @@ namespace dlib
public: public:
cache_element_structural_svm ( cache_element_structural_svm (
) : prob(0), sample_idx(0) {} ) : prob(0), sample_idx(0), last_true_risk_computed(std::numeric_limits<double>::infinity()) {}
typedef typename structural_svm_problem::scalar_type scalar_type; typedef typename structural_svm_problem::scalar_type scalar_type;
typedef typename structural_svm_problem::matrix_type matrix_type; typedef typename structural_svm_problem::matrix_type matrix_type;
...@@ -66,19 +66,23 @@ namespace dlib ...@@ -66,19 +66,23 @@ namespace dlib
void separation_oracle_cached ( void separation_oracle_cached (
const bool skip_cache, const bool skip_cache,
const scalar_type& cur_risk_lower_bound, const scalar_type& saved_current_risk_gap,
const matrix_type& current_solution, const matrix_type& current_solution,
scalar_type& out_loss, scalar_type& out_loss,
feature_vector_type& out_psi feature_vector_type& out_psi
) const ) const
{ {
if (!skip_cache && prob->get_max_cache_size() != 0) const bool cache_enabled = prob->get_max_cache_size() != 0;
// Don't waste time computing this if the cache isn't going to be used.
const scalar_type dot_true_psi = cache_enabled ? dot(true_psi, current_solution) : 0;
if (!skip_cache && cache_enabled)
{ {
scalar_type best_risk = -std::numeric_limits<scalar_type>::infinity(); scalar_type best_risk = -std::numeric_limits<scalar_type>::infinity();
unsigned long best_idx = 0; unsigned long best_idx = 0;
const scalar_type dot_true_psi = dot(true_psi, current_solution);
// figure out which element in the cache is the best (i.e. has the biggest risk) // figure out which element in the cache is the best (i.e. has the biggest risk)
long max_lru_count = 0; long max_lru_count = 0;
...@@ -95,7 +99,11 @@ namespace dlib ...@@ -95,7 +99,11 @@ namespace dlib
max_lru_count = lru_count[i]; max_lru_count = lru_count[i];
} }
if (best_risk - cur_risk_lower_bound > prob->get_epsilon()) // Check if the best psi vector in the cache is still good enough to use as
// a proxy for the true separation oracle. If the risk value has dropped
// by enough to get into the stopping condition then the best psi isn't
// good enough.
if (best_risk + saved_current_risk_gap-prob->get_epsilon() > last_true_risk_computed)
{ {
out_psi = psi[best_idx]; out_psi = psi[best_idx];
lru_count[best_idx] = max_lru_count + 1; lru_count[best_idx] = max_lru_count + 1;
...@@ -106,11 +114,13 @@ namespace dlib ...@@ -106,11 +114,13 @@ namespace dlib
prob->separation_oracle(sample_idx, current_solution, out_loss, out_psi); prob->separation_oracle(sample_idx, current_solution, out_loss, out_psi);
if (prob->get_max_cache_size() == 0) if (!cache_enabled)
return; return;
compact_sparse_vector(out_psi); compact_sparse_vector(out_psi);
last_true_risk_computed = out_loss + dot(out_psi, current_solution) - dot_true_psi;
// if the cache is full // if the cache is full
if (loss.size() >= prob->get_max_cache_size()) if (loss.size() >= prob->get_max_cache_size())
{ {
...@@ -167,6 +177,7 @@ namespace dlib ...@@ -167,6 +177,7 @@ namespace dlib
mutable std::vector<scalar_type> loss; mutable std::vector<scalar_type> loss;
mutable std::vector<feature_vector_type> psi; mutable std::vector<feature_vector_type> psi;
mutable std::vector<long> lru_count; mutable std::vector<long> lru_count;
mutable double last_true_risk_computed;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -200,7 +211,7 @@ namespace dlib ...@@ -200,7 +211,7 @@ namespace dlib
structural_svm_problem ( structural_svm_problem (
) : ) :
cur_risk_lower_bound(0), saved_current_risk_gap(0),
eps(0.001), eps(0.001),
verbose(false), verbose(false),
skip_cache(true), skip_cache(true),
...@@ -315,7 +326,8 @@ namespace dlib ...@@ -315,7 +326,8 @@ namespace dlib
cout << endl; cout << endl;
} }
cur_risk_lower_bound = std::max<scalar_type>(current_risk_value - current_risk_gap, 0); saved_current_risk_gap = std::max<scalar_type>(current_risk_value - current_risk_gap, 0);
saved_current_risk_gap = current_risk_gap;
bool should_stop = false; bool should_stop = false;
...@@ -401,7 +413,7 @@ namespace dlib ...@@ -401,7 +413,7 @@ namespace dlib
) const ) const
{ {
cache[idx].separation_oracle_cached(skip_cache, cache[idx].separation_oracle_cached(skip_cache,
cur_risk_lower_bound, saved_current_risk_gap,
current_solution, current_solution,
loss, loss,
psi); psi);
...@@ -409,7 +421,7 @@ namespace dlib ...@@ -409,7 +421,7 @@ namespace dlib
private: private:
mutable scalar_type cur_risk_lower_bound; mutable scalar_type saved_current_risk_gap;
mutable matrix_type psi_true; mutable matrix_type psi_true;
scalar_type eps; scalar_type eps;
mutable bool verbose; mutable bool verbose;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment