Commit 29f22685 authored by Davis King's avatar Davis King

Fixed set_prior() so it works with sparse vectors in addition to dense vectors.

parent d7f207f2
......@@ -37,13 +37,15 @@ namespace dlib
const std::vector<ranking_pair<sample_type> >& samples_,
const bool be_verbose_,
const scalar_type eps_,
const unsigned long max_iter
const unsigned long max_iter,
const unsigned long dims_
) :
samples(samples_),
C(C_),
be_verbose(be_verbose_),
eps(eps_),
max_iterations(max_iter)
max_iterations(max_iter),
dims(dims_)
{
}
......@@ -56,7 +58,7 @@ namespace dlib
virtual long get_num_dimensions (
) const
{
return max_index_plus_one(samples);
return dims;
}
virtual bool optimization_status (
......@@ -173,6 +175,7 @@ namespace dlib
const bool be_verbose;
const scalar_type eps;
const unsigned long max_iterations;
const unsigned long dims;
};
// ----------------------------------------------------------------------------------------
......@@ -187,11 +190,12 @@ namespace dlib
const std::vector<ranking_pair<sample_type> >& samples,
const bool be_verbose,
const scalar_type eps,
const unsigned long max_iterations
const unsigned long max_iterations,
const unsigned long dims
)
{
return oca_problem_ranking_svm<matrix_type, sample_type>(
C, samples, be_verbose, eps, max_iterations);
C, samples, be_verbose, eps, max_iterations, dims);
}
// ----------------------------------------------------------------------------------------
......@@ -346,7 +350,7 @@ namespace dlib
<< "\n\t this: " << this
);
prior = prior_.basis_vectors(0);
prior = sparse_to_dense(prior_.basis_vectors(0));
learn_nonnegative_weights = false;
last_weight_1 = false;
}
......@@ -421,13 +425,29 @@ namespace dlib
<< "\n\t prior.size(): " << prior.size()
);
}
solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations),
w,
prior);
const unsigned long dims = std::max(num_dims, (unsigned long)prior.size());
// In the case of sparse sample vectors, it is possible that the input
// vector dimensionality is larger than the prior vector dimensionality.
// We need to check for this case and pad prior with zeros if it is the
// case.
if ((unsigned long)prior.size() < dims)
{
matrix<scalar_type,0,1> prior_temp = join_cols(prior, zeros_matrix<scalar_type>(dims-prior.size(),1));
solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations, dims),
w,
prior_temp);
}
else
{
solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations, dims),
w,
prior);
}
}
else
{
solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations),
solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations, num_dims),
w,
num_nonnegative,
force_weight_1_idx);
......
......@@ -6,6 +6,7 @@
#include <string>
#include <cstdlib>
#include <ctime>
#include <map>
#include "tester.h"
......@@ -107,6 +108,41 @@ namespace
DLIB_TEST(df.basis_vectors(0)(2) > 0);
}
// ----------------------------------------------------------------------------------------
void run_prior_sparse_test()
{
print_spinner();
typedef std::map<unsigned long,double> sample_type;
typedef sparse_linear_kernel<sample_type> kernel_type;
svm_rank_trainer<kernel_type> trainer;
ranking_pair<sample_type> data;
sample_type samp;
samp[0] = 1; data.relevant.push_back(samp); samp.clear();
samp[1] = 1; data.nonrelevant.push_back(samp); samp.clear();
trainer.set_c(10);
decision_function<kernel_type> df = trainer.train(data);
trainer.set_prior(df);
data.relevant.clear();
data.nonrelevant.clear();
samp[2] = 1; data.relevant.push_back(samp); samp.clear();
samp[1] = 1; data.nonrelevant.push_back(samp); samp.clear();
df = trainer.train(data);
matrix<double,0,1> w = sparse_to_dense(df.basis_vectors(0));
dlog << LINFO << trans(w);
DLIB_TEST(w(0) > 0.1);
DLIB_TEST(w(1) < -0.1);
DLIB_TEST(w(2) > 0.1);
}
// ----------------------------------------------------------------------------------------
void dotest1()
......@@ -390,6 +426,7 @@ namespace
test_svmrank_weight_force_dense<true>();
test_svmrank_weight_force_dense<false>();
run_prior_test();
run_prior_sparse_test();
}
} a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment