Commit 29f22685 authored by Davis King's avatar Davis King

Fixed set_prior() so it works with sparse vectors in addition to dense vectors.

parent d7f207f2
...@@ -37,13 +37,15 @@ namespace dlib ...@@ -37,13 +37,15 @@ namespace dlib
const std::vector<ranking_pair<sample_type> >& samples_, const std::vector<ranking_pair<sample_type> >& samples_,
const bool be_verbose_, const bool be_verbose_,
const scalar_type eps_, const scalar_type eps_,
const unsigned long max_iter const unsigned long max_iter,
const unsigned long dims_
) : ) :
samples(samples_), samples(samples_),
C(C_), C(C_),
be_verbose(be_verbose_), be_verbose(be_verbose_),
eps(eps_), eps(eps_),
max_iterations(max_iter) max_iterations(max_iter),
dims(dims_)
{ {
} }
...@@ -56,7 +58,7 @@ namespace dlib ...@@ -56,7 +58,7 @@ namespace dlib
virtual long get_num_dimensions ( virtual long get_num_dimensions (
) const ) const
{ {
return max_index_plus_one(samples); return dims;
} }
virtual bool optimization_status ( virtual bool optimization_status (
...@@ -173,6 +175,7 @@ namespace dlib ...@@ -173,6 +175,7 @@ namespace dlib
const bool be_verbose; const bool be_verbose;
const scalar_type eps; const scalar_type eps;
const unsigned long max_iterations; const unsigned long max_iterations;
const unsigned long dims;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -187,11 +190,12 @@ namespace dlib ...@@ -187,11 +190,12 @@ namespace dlib
const std::vector<ranking_pair<sample_type> >& samples, const std::vector<ranking_pair<sample_type> >& samples,
const bool be_verbose, const bool be_verbose,
const scalar_type eps, const scalar_type eps,
const unsigned long max_iterations const unsigned long max_iterations,
const unsigned long dims
) )
{ {
return oca_problem_ranking_svm<matrix_type, sample_type>( return oca_problem_ranking_svm<matrix_type, sample_type>(
C, samples, be_verbose, eps, max_iterations); C, samples, be_verbose, eps, max_iterations, dims);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -346,7 +350,7 @@ namespace dlib ...@@ -346,7 +350,7 @@ namespace dlib
<< "\n\t this: " << this << "\n\t this: " << this
); );
prior = prior_.basis_vectors(0); prior = sparse_to_dense(prior_.basis_vectors(0));
learn_nonnegative_weights = false; learn_nonnegative_weights = false;
last_weight_1 = false; last_weight_1 = false;
} }
...@@ -421,13 +425,29 @@ namespace dlib ...@@ -421,13 +425,29 @@ namespace dlib
<< "\n\t prior.size(): " << prior.size() << "\n\t prior.size(): " << prior.size()
); );
} }
solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations), const unsigned long dims = std::max(num_dims, (unsigned long)prior.size());
// In the case of sparse sample vectors, it is possible that the input
// vector dimensionality is larger than the prior vector dimensionality.
// We need to check for this case and pad prior with zeros if it is the
// case.
if ((unsigned long)prior.size() < dims)
{
matrix<scalar_type,0,1> prior_temp = join_cols(prior, zeros_matrix<scalar_type>(dims-prior.size(),1));
solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations, dims),
w,
prior_temp);
}
else
{
solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations, dims),
w, w,
prior); prior);
} }
}
else else
{ {
solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations), solver( make_oca_problem_ranking_svm<w_type>(C, samples, verbose, eps, max_iterations, num_dims),
w, w,
num_nonnegative, num_nonnegative,
force_weight_1_idx); force_weight_1_idx);
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <string> #include <string>
#include <cstdlib> #include <cstdlib>
#include <ctime> #include <ctime>
#include <map>
#include "tester.h" #include "tester.h"
...@@ -107,6 +108,41 @@ namespace ...@@ -107,6 +108,41 @@ namespace
DLIB_TEST(df.basis_vectors(0)(2) > 0); DLIB_TEST(df.basis_vectors(0)(2) > 0);
} }
// ----------------------------------------------------------------------------------------
void run_prior_sparse_test()
{
print_spinner();
typedef std::map<unsigned long,double> sample_type;
typedef sparse_linear_kernel<sample_type> kernel_type;
svm_rank_trainer<kernel_type> trainer;
ranking_pair<sample_type> data;
sample_type samp;
samp[0] = 1; data.relevant.push_back(samp); samp.clear();
samp[1] = 1; data.nonrelevant.push_back(samp); samp.clear();
trainer.set_c(10);
decision_function<kernel_type> df = trainer.train(data);
trainer.set_prior(df);
data.relevant.clear();
data.nonrelevant.clear();
samp[2] = 1; data.relevant.push_back(samp); samp.clear();
samp[1] = 1; data.nonrelevant.push_back(samp); samp.clear();
df = trainer.train(data);
matrix<double,0,1> w = sparse_to_dense(df.basis_vectors(0));
dlog << LINFO << trans(w);
DLIB_TEST(w(0) > 0.1);
DLIB_TEST(w(1) < -0.1);
DLIB_TEST(w(2) > 0.1);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
void dotest1() void dotest1()
...@@ -390,6 +426,7 @@ namespace ...@@ -390,6 +426,7 @@ namespace
test_svmrank_weight_force_dense<true>(); test_svmrank_weight_force_dense<true>();
test_svmrank_weight_force_dense<false>(); test_svmrank_weight_force_dense<false>();
run_prior_test(); run_prior_test();
run_prior_sparse_test();
} }
} a; } a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment