Commit ea2a5184 authored by Davis King's avatar Davis King

Cleaned up the code a little by pulling the caching logic out into its

own class.
parent 007e218e
......@@ -17,10 +17,134 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename matrix_type,
typename feature_vector_type_ = matrix_type
typename structural_svm_problem
>
class structural_svm_problem : public oca_problem<matrix_type>
class cache_element_structural_svm
{
public:
cache_element_structural_svm (
) : prob(0), sample_idx(0) {}
typedef typename structural_svm_problem::scalar_type scalar_type;
typedef typename structural_svm_problem::matrix_type matrix_type;
typedef typename structural_svm_problem::feature_vector_type feature_vector_type;
void init (
const structural_svm_problem* prob_,
const long idx
)
/*!
ensures
- This object will be a cache for the idx-th sample in the given
structural_svm_problem.
!*/
{
prob = prob_;
sample_idx = idx;
loss.clear();
psi.clear();
lru_count.clear();
prob->get_truth_joint_feature_vector(idx, true_psi);
}
void get_truth_joint_feature_vector_cached (
feature_vector_type& psi
) const
{
psi = true_psi;
}
void separation_oracle_cached (
const bool skip_cache,
const scalar_type& cur_risk_lower_bound,
const matrix_type& current_solution,
scalar_type& out_loss,
feature_vector_type& out_psi
) const
{
if (!skip_cache)
{
scalar_type best_risk = -std::numeric_limits<scalar_type>::infinity();
unsigned long best_idx = 0;
using sparse_vector::dot;
using dlib::dot;
const scalar_type dot_true_psi = dot(true_psi, current_solution);
// figure out which element in the cache is the best (i.e. has the biggest risk)
long max_lru_count = 0;
for (unsigned long i = 0; i < loss.size(); ++i)
{
const scalar_type risk = loss[i] + dot(psi[i], current_solution) - dot_true_psi;
if (risk > best_risk)
{
best_risk = risk;
out_loss = loss[i];
best_idx = i;
}
if (lru_count[i] > max_lru_count)
max_lru_count = lru_count[i];
}
if (best_risk - cur_risk_lower_bound > prob->get_epsilon())
{
out_psi = psi[best_idx];
lru_count[best_idx] = max_lru_count + 1;
return;
}
}
prob->separation_oracle(sample_idx, current_solution, out_loss, out_psi);
// if the cache is full
if (loss.size() >= prob->get_max_cache_size())
{
// find least recently used cache entry for idx-th sample
const long i = index_of_min(vector_to_matrix(lru_count));
// save our new data in the cache
loss[i] = out_loss;
psi[i] = out_psi;
const long max_use = max(vector_to_matrix(lru_count));
// Make sure this new cache entry has the best lru count since we have used
// it most recently.
lru_count[i] = max_use + 1;
}
else
{
loss.push_back(out_loss);
psi.push_back(out_psi);
long max_use = 1;
if (lru_count.size() != 0)
max_use = max(vector_to_matrix(lru_count)) + 1;
lru_count.push_back(lru_count.size());
}
}
const structural_svm_problem* prob;
long sample_idx;
mutable feature_vector_type true_psi;
mutable std::vector<scalar_type> loss;
mutable std::vector<feature_vector_type> psi;
mutable std::vector<long> lru_count;
};
// ----------------------------------------------------------------------------------------
template <
typename matrix_type_,
typename feature_vector_type_ = matrix_type_
>
class structural_svm_problem : public oca_problem<matrix_type_>
{
public:
/*!
......@@ -35,11 +159,11 @@ namespace dlib
- if (cache.size() != 0) then
- cache.size() == get_num_samples()
- true_psis.size() == get_num_samples()
- for all i: cache[i] == the cached results of calls to separation_oracle()
for the i-th sample.
!*/
typedef matrix_type_ matrix_type;
typedef typename matrix_type::type scalar_type;
typedef feature_vector_type_ feature_vector_type;
......@@ -193,36 +317,30 @@ namespace dlib
feature_vector_type ftemp;
const unsigned long num = get_num_samples();
// initialize psi_true and a few other things if we haven't done so already.
if (psi_true.size() == 0)
{
// initialize the cache if necessary.
if (cache.size() == 0 && max_cache_size != 0)
{
cache.resize(get_num_samples());
for (unsigned long i = 0; i < cache.size(); ++i)
cache[i].init(this,i);
}
// initialize psi_true if necessary.
if (psi_true.size() == 0)
{
psi_true.set_size(w.size(),1);
psi_true = 0;
// If the cache is enabled then populate the true_psis array. But
// in either case sum them all up and store the result in psi_true.
if (max_cache_size != 0)
{
true_psis.resize(num);
for (unsigned long i = 0; i < num; ++i)
{
get_truth_joint_feature_vector(i, true_psis[i]);
sparse_vector::subtract_from(psi_true, true_psis[i]);
}
}
else
{
for (unsigned long i = 0; i < num; ++i)
{
if (cache.size() == 0)
get_truth_joint_feature_vector(i, ftemp);
else
cache[i].get_truth_joint_feature_vector_cached(ftemp);
sparse_vector::subtract_from(psi_true, ftemp);
}
}
}
subgradient = psi_true;
scalar_type total_loss = 0;
......@@ -259,90 +377,29 @@ namespace dlib
feature_vector_type& psi
) const
{
if (!skip_cache && max_cache_size != 0)
{
scalar_type best_risk = -std::numeric_limits<scalar_type>::infinity();
unsigned long best_idx = 0;
cache_record& rec = cache[idx];
using sparse_vector::dot;
using dlib::dot;
const scalar_type dot_true_psi = dot(true_psis[idx], current_solution);
// figure out which element in the cache is the best (i.e. has the biggest risk)
long max_lru_count = 0;
for (unsigned long i = 0; i < rec.loss.size(); ++i)
{
const scalar_type risk = rec.loss[i] + dot(rec.psi[i], current_solution) - dot_true_psi;
if (risk > best_risk)
{
best_risk = risk;
loss = rec.loss[i];
best_idx = i;
}
if (rec.lru_count[i] > max_lru_count)
max_lru_count = rec.lru_count[i];
}
if (best_risk - cur_risk_lower_bound > eps)
if (cache.size() == 0)
{
psi = rec.psi[best_idx];
rec.lru_count[best_idx] = max_lru_count + 1;
return;
}
}
separation_oracle(idx, current_solution, loss, psi);
if (cache.size() != 0)
{
if (cache[idx].loss.size() < max_cache_size)
{
cache[idx].loss.push_back(loss);
cache[idx].psi.push_back(psi);
long max_use = 1;
if (cache[idx].lru_count.size() != 0)
max_use = max(vector_to_matrix(cache[idx].lru_count)) + 1;
cache[idx].lru_count.push_back(cache[idx].lru_count.size());
}
else
{
// find least recently used cache entry for idx-th sample
const long i = index_of_min(vector_to_matrix(cache[idx].lru_count));
// save our new data in the cache
cache[idx].loss[i] = loss;
cache[idx].psi[i] = psi;
const long max_use = max(vector_to_matrix(cache[idx].lru_count));
// Make sure this new cache entry has the best lru count since we have used
// it most recently.
cache[idx].lru_count[i] = max_use + 1;
}
cache[idx].separation_oracle_cached(skip_cache,
cur_risk_lower_bound,
current_solution,
loss,
psi);
}
}
private:
struct cache_record
{
std::vector<scalar_type> loss;
std::vector<feature_vector_type> psi;
std::vector<long> lru_count;
};
mutable scalar_type cur_risk_lower_bound;
mutable matrix_type psi_true;
scalar_type eps;
mutable bool verbose;
mutable std::vector<feature_vector_type> true_psis;
mutable std::vector<cache_record> cache;
mutable std::vector<cache_element_structural_svm<structural_svm_problem> > cache;
mutable bool skip_cache;
unsigned long max_cache_size;
......
......@@ -13,15 +13,15 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename matrix_type,
typename feature_vector_type_ = matrix_type
typename matrix_type_,
typename feature_vector_type_ = matrix_type_
>
class structural_svm_problem : public oca_problem<matrix_type>
class structural_svm_problem : public oca_problem<matrix_type_>
{
public:
/*!
REQUIREMENTS ON matrix_type
- matrix_type == a dlib::matrix capable of storing column vectors
REQUIREMENTS ON matrix_type_
- matrix_type_ == a dlib::matrix capable of storing column vectors
REQUIREMENTS ON feature_vector_type_
- feature_vector_type_ == a dlib::matrix capable of storing column vectors
......@@ -81,6 +81,7 @@ namespace dlib
paper.
!*/
typedef matrix_type_ matrix_type;
typedef typename matrix_type::type scalar_type;
typedef feature_vector_type_ feature_vector_type;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment