Commit 39c9a234 authored by Davis King's avatar Davis King

Simplified code a little bit.

parent 1f254ef3
...@@ -47,14 +47,18 @@ namespace dlib ...@@ -47,14 +47,18 @@ namespace dlib
psi.clear(); psi.clear();
lru_count.clear(); lru_count.clear();
prob->get_truth_joint_feature_vector(idx, true_psi); if (prob->get_max_cache_size() != 0)
prob->get_truth_joint_feature_vector(idx, true_psi);
} }
void get_truth_joint_feature_vector_cached ( void get_truth_joint_feature_vector_cached (
feature_vector_type& psi feature_vector_type& psi
) const ) const
{ {
psi = true_psi; if (prob->get_max_cache_size() != 0)
psi = true_psi;
else
prob->get_truth_joint_feature_vector(sample_idx, psi);
} }
void separation_oracle_cached ( void separation_oracle_cached (
...@@ -320,26 +324,19 @@ namespace dlib ...@@ -320,26 +324,19 @@ namespace dlib
feature_vector_type ftemp; feature_vector_type ftemp;
const unsigned long num = get_num_samples(); const unsigned long num = get_num_samples();
// initialize the cache if necessary. // initialize the cache and compute psi_true.
if (cache.size() == 0 && max_cache_size != 0) if (cache.size() == 0)
{ {
cache.resize(get_num_samples()); cache.resize(get_num_samples());
for (unsigned long i = 0; i < cache.size(); ++i) for (unsigned long i = 0; i < cache.size(); ++i)
cache[i].init(this,i); cache[i].init(this,i);
}
// initialize psi_true if necessary.
if (psi_true.size() == 0)
{
psi_true.set_size(w.size(),1); psi_true.set_size(w.size(),1);
psi_true = 0; psi_true = 0;
for (unsigned long i = 0; i < num; ++i) for (unsigned long i = 0; i < num; ++i)
{ {
if (cache.size() == 0) cache[i].get_truth_joint_feature_vector_cached(ftemp);
get_truth_joint_feature_vector(i, ftemp);
else
cache[i].get_truth_joint_feature_vector_cached(ftemp);
sparse_vector::subtract_from(psi_true, ftemp); sparse_vector::subtract_from(psi_true, ftemp);
} }
...@@ -380,18 +377,11 @@ namespace dlib ...@@ -380,18 +377,11 @@ namespace dlib
feature_vector_type& psi feature_vector_type& psi
) const ) const
{ {
if (cache.size() == 0) cache[idx].separation_oracle_cached(skip_cache,
{ cur_risk_lower_bound,
separation_oracle(idx, current_solution, loss, psi); current_solution,
} loss,
else psi);
{
cache[idx].separation_oracle_cached(skip_cache,
cur_risk_lower_bound,
current_solution,
loss,
psi);
}
} }
private: private:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment