Commit 39c9a234 authored by Davis King's avatar Davis King

Simplified code a little bit.

parent 1f254ef3
...@@ -47,6 +47,7 @@ namespace dlib ...@@ -47,6 +47,7 @@ namespace dlib
psi.clear(); psi.clear();
lru_count.clear(); lru_count.clear();
if (prob->get_max_cache_size() != 0)
prob->get_truth_joint_feature_vector(idx, true_psi); prob->get_truth_joint_feature_vector(idx, true_psi);
} }
...@@ -54,7 +55,10 @@ namespace dlib ...@@ -54,7 +55,10 @@ namespace dlib
feature_vector_type& psi feature_vector_type& psi
) const ) const
{ {
if (prob->get_max_cache_size() != 0)
psi = true_psi; psi = true_psi;
else
prob->get_truth_joint_feature_vector(sample_idx, psi);
} }
void separation_oracle_cached ( void separation_oracle_cached (
...@@ -320,25 +324,18 @@ namespace dlib ...@@ -320,25 +324,18 @@ namespace dlib
feature_vector_type ftemp; feature_vector_type ftemp;
const unsigned long num = get_num_samples(); const unsigned long num = get_num_samples();
// initialize the cache if necessary. // initialize the cache and compute psi_true.
if (cache.size() == 0 && max_cache_size != 0) if (cache.size() == 0)
{ {
cache.resize(get_num_samples()); cache.resize(get_num_samples());
for (unsigned long i = 0; i < cache.size(); ++i) for (unsigned long i = 0; i < cache.size(); ++i)
cache[i].init(this,i); cache[i].init(this,i);
}
// initialize psi_true if necessary.
if (psi_true.size() == 0)
{
psi_true.set_size(w.size(),1); psi_true.set_size(w.size(),1);
psi_true = 0; psi_true = 0;
for (unsigned long i = 0; i < num; ++i) for (unsigned long i = 0; i < num; ++i)
{ {
if (cache.size() == 0)
get_truth_joint_feature_vector(i, ftemp);
else
cache[i].get_truth_joint_feature_vector_cached(ftemp); cache[i].get_truth_joint_feature_vector_cached(ftemp);
sparse_vector::subtract_from(psi_true, ftemp); sparse_vector::subtract_from(psi_true, ftemp);
...@@ -379,12 +376,6 @@ namespace dlib ...@@ -379,12 +376,6 @@ namespace dlib
scalar_type& loss, scalar_type& loss,
feature_vector_type& psi feature_vector_type& psi
) const ) const
{
if (cache.size() == 0)
{
separation_oracle(idx, current_solution, loss, psi);
}
else
{ {
cache[idx].separation_oracle_cached(skip_cache, cache[idx].separation_oracle_cached(skip_cache,
cur_risk_lower_bound, cur_risk_lower_bound,
...@@ -392,7 +383,6 @@ namespace dlib ...@@ -392,7 +383,6 @@ namespace dlib
loss, loss,
psi); psi);
} }
}
private: private:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment