Commit 721597f2 authored by Davis King

Made set_prior() work with sparse vectors.

parent 05c0b373
...@@ -44,7 +44,8 @@ namespace dlib ...@@ -44,7 +44,8 @@ namespace dlib
const in_scalar_vector_type& labels_, const in_scalar_vector_type& labels_,
const bool be_verbose_, const bool be_verbose_,
const scalar_type eps_, const scalar_type eps_,
const unsigned long max_iter const unsigned long max_iter,
const unsigned long dims_
) : ) :
samples(samples_), samples(samples_),
labels(labels_), labels(labels_),
...@@ -53,7 +54,8 @@ namespace dlib ...@@ -53,7 +54,8 @@ namespace dlib
Cneg(C_neg/C), Cneg(C_neg/C),
be_verbose(be_verbose_), be_verbose(be_verbose_),
eps(eps_), eps(eps_),
max_iterations(max_iter) max_iterations(max_iter),
dims(dims_)
{ {
dot_prods.resize(samples.size()); dot_prods.resize(samples.size());
is_first_call = true; is_first_call = true;
...@@ -69,7 +71,7 @@ namespace dlib ...@@ -69,7 +71,7 @@ namespace dlib
) const ) const
{ {
// plus 1 for the bias term // plus 1 for the bias term
return max_index_plus_one(samples) + 1; return dims + 1;
} }
virtual bool optimization_status ( virtual bool optimization_status (
...@@ -300,6 +302,7 @@ namespace dlib ...@@ -300,6 +302,7 @@ namespace dlib
const bool be_verbose; const bool be_verbose;
const scalar_type eps; const scalar_type eps;
const unsigned long max_iterations; const unsigned long max_iterations;
const unsigned long dims;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -317,11 +320,12 @@ namespace dlib ...@@ -317,11 +320,12 @@ namespace dlib
const in_scalar_vector_type& labels, const in_scalar_vector_type& labels,
const bool be_verbose, const bool be_verbose,
const scalar_type eps, const scalar_type eps,
const unsigned long max_iterations const unsigned long max_iterations,
const unsigned long dims
) )
{ {
return oca_problem_c_svm<matrix_type, in_sample_vector_type, in_scalar_vector_type>( return oca_problem_c_svm<matrix_type, in_sample_vector_type, in_scalar_vector_type>(
C_pos, C_neg, samples, labels, be_verbose, eps, max_iterations); C_pos, C_neg, samples, labels, be_verbose, eps, max_iterations, dims);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -478,7 +482,8 @@ namespace dlib ...@@ -478,7 +482,8 @@ namespace dlib
<< "\n\t this: " << this << "\n\t this: " << this
); );
prior = join_cols(prior_.basis_vectors(0), mat((scalar_type)prior_.b)); prior = sparse_to_dense(prior_.basis_vectors(0));
prior_b = prior_.b;
learn_nonnegative_weights = false; learn_nonnegative_weights = false;
last_weight_1 = false; last_weight_1 = false;
} }
...@@ -631,7 +636,7 @@ namespace dlib ...@@ -631,7 +636,7 @@ namespace dlib
if (is_matrix<sample_type>::value) if (is_matrix<sample_type>::value)
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_CASSERT(num_dims+1 == (unsigned long)prior.size(), DLIB_CASSERT(num_dims == (unsigned long)prior.size(),
"\t decision_function svm_c_linear_trainer::train(x,y)" "\t decision_function svm_c_linear_trainer::train(x,y)"
<< "\n\t The dimension of the training vectors must match the dimension of\n" << "\n\t The dimension of the training vectors must match the dimension of\n"
<< "\n\t those used to create the prior." << "\n\t those used to create the prior."
...@@ -639,15 +644,24 @@ namespace dlib ...@@ -639,15 +644,24 @@ namespace dlib
<< "\n\t prior.size(): " << prior.size() << "\n\t prior.size(): " << prior.size()
); );
} }
const unsigned long dims = std::max(num_dims, (unsigned long)prior.size());
// In the case of sparse sample vectors, it is possible that the input
// vector dimensionality is larger than the prior vector dimensionality.
// We need to check for this case and pad prior with zeros if it is the
// case.
matrix<scalar_type,0,1> prior_temp = join_cols(join_cols(prior,
zeros_matrix<scalar_type>(dims-prior.size(),1)),
mat(prior_b));
svm_objective = solver( svm_objective = solver(
make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps, max_iterations), make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps, max_iterations, dims),
w, w,
prior); prior_temp);
} }
else else
{ {
svm_objective = solver( svm_objective = solver(
make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps, max_iterations), make_oca_problem_c_svm<w_type>(Cpos, Cneg, x, y, verbose, eps, max_iterations, num_dims),
w, w,
num_nonnegative, num_nonnegative,
force_weight_1_idx); force_weight_1_idx);
...@@ -678,6 +692,7 @@ namespace dlib ...@@ -678,6 +692,7 @@ namespace dlib
bool learn_nonnegative_weights; bool learn_nonnegative_weights;
bool last_weight_1; bool last_weight_1;
matrix<scalar_type,0,1> prior; matrix<scalar_type,0,1> prior;
scalar_type prior_b;
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -66,42 +66,42 @@ namespace ...@@ -66,42 +66,42 @@ namespace
oca solver; oca solver;
// test the version without a non-negativity constraint on w. // test the version without a non-negativity constraint on w.
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40), w, 0); solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 0);
dlog << LINFO << trans(w); dlog << LINFO << trans(w);
true_w = -0.5, 0.5, 0; true_w = -0.5, 0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w)); dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10); DLIB_TEST(max(abs(w-true_w)) < 1e-10);
w_type prior = true_w; w_type prior = true_w;
solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40), w, prior); solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior);
dlog << LINFO << trans(w); dlog << LINFO << trans(w);
true_w = -0.5, 0.5, 0; true_w = -0.5, 0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w)); dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10); DLIB_TEST(max(abs(w-true_w)) < 1e-10);
prior = 0,0,0; prior = 0,0,0;
solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40), w, prior); solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior);
dlog << LINFO << trans(w); dlog << LINFO << trans(w);
true_w = -0.5, 0.5, 0; true_w = -0.5, 0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w)); dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10); DLIB_TEST(max(abs(w-true_w)) < 1e-10);
prior = -1,1,0; prior = -1,1,0;
solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40), w, prior); solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior);
dlog << LINFO << trans(w); dlog << LINFO << trans(w);
true_w = -1.0, 1.0, 0; true_w = -1.0, 1.0, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w)); dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10); DLIB_TEST(max(abs(w-true_w)) < 1e-10);
prior = -0.2,0.2,0; prior = -0.2,0.2,0;
solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40), w, prior); solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior);
dlog << LINFO << trans(w); dlog << LINFO << trans(w);
true_w = -0.5, 0.5, 0; true_w = -0.5, 0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w)); dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10); DLIB_TEST(max(abs(w-true_w)) < 1e-10);
prior = -10.2,-1,0; prior = -10.2,-1,0;
solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40), w, prior); solver(make_oca_problem_c_svm<w_type>(20.0, 30.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, prior);
dlog << LINFO << trans(w); dlog << LINFO << trans(w);
true_w = -10.2, -1.0, 0; true_w = -10.2, -1.0, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w)); dlog << LINFO << "error: "<< max(abs(w-true_w));
...@@ -110,7 +110,7 @@ namespace ...@@ -110,7 +110,7 @@ namespace
print_spinner(); print_spinner();
// test the version with a non-negativity constraint on w. // test the version with a non-negativity constraint on w.
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40), w, 9999); solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 9999);
dlog << LINFO << trans(w); dlog << LINFO << trans(w);
true_w = 0, 1, 0; true_w = 0, 1, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w)); dlog << LINFO << "error: "<< max(abs(w-true_w));
...@@ -126,7 +126,7 @@ namespace ...@@ -126,7 +126,7 @@ namespace
print_spinner(); print_spinner();
// test the version with a non-negativity constraint on w. // test the version with a non-negativity constraint on w.
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40), w, 2); solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 2);
dlog << LINFO << trans(w); dlog << LINFO << trans(w);
true_w = 0, 1, 0; true_w = 0, 1, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w)); dlog << LINFO << "error: "<< max(abs(w-true_w));
...@@ -136,7 +136,7 @@ namespace ...@@ -136,7 +136,7 @@ namespace
// test the version with a non-negativity constraint on w. // test the version with a non-negativity constraint on w.
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40), w, 1); solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 1);
dlog << LINFO << trans(w); dlog << LINFO << trans(w);
true_w = 0, 1, 0; true_w = 0, 1, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w)); dlog << LINFO << "error: "<< max(abs(w-true_w));
...@@ -151,7 +151,7 @@ namespace ...@@ -151,7 +151,7 @@ namespace
y.push_back(+1); y.push_back(+1);
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40), w, 0); solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 0);
dlog << LINFO << trans(w); dlog << LINFO << trans(w);
true_w = 0.5, -0.5, 0; true_w = 0.5, -0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w)); dlog << LINFO << "error: "<< max(abs(w-true_w));
...@@ -159,7 +159,7 @@ namespace ...@@ -159,7 +159,7 @@ namespace
print_spinner(); print_spinner();
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40), w, 1); solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 1);
dlog << LINFO << trans(w); dlog << LINFO << trans(w);
true_w = 0.5, -0.5, 0; true_w = 0.5, -0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w)); dlog << LINFO << "error: "<< max(abs(w-true_w));
...@@ -167,7 +167,7 @@ namespace ...@@ -167,7 +167,7 @@ namespace
print_spinner(); print_spinner();
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40), w, 2); solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 2);
dlog << LINFO << trans(w); dlog << LINFO << trans(w);
true_w = 1, 0, 0; true_w = 1, 0, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w)); dlog << LINFO << "error: "<< max(abs(w-true_w));
...@@ -175,7 +175,7 @@ namespace ...@@ -175,7 +175,7 @@ namespace
print_spinner(); print_spinner();
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40), w, 5); solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, mat(x), mat(y), false, 1e-12, 40, max_index_plus_one(x)), w, 5);
dlog << LINFO << trans(w); dlog << LINFO << trans(w);
true_w = 1, 0, 0; true_w = 1, 0, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w)); dlog << LINFO << "error: "<< max(abs(w-true_w));
......
...@@ -70,6 +70,44 @@ namespace ...@@ -70,6 +70,44 @@ namespace
DLIB_TEST(df.basis_vectors(0)(2) > 0); DLIB_TEST(df.basis_vectors(0)(2) > 0);
} }
// Regression test for svm_c_linear_trainer::set_prior() with sparse samples.
//
// A first model is trained on samples that only use feature indices 0 and 1.
// That model is installed as the prior, and a second model is trained on
// samples that use indices 2 and 1, so the new data refers to a feature index
// the prior was never trained on.  The test then checks that the second model
// separates its training data and that the learned weight vector has the
// expected signs at indices 0, 1 and 2.
void run_prior_sparse_test()
{
    typedef std::map<unsigned long,double> sparse_vector_type;
    typedef sparse_linear_kernel<sparse_vector_type> kernel_type;

    svm_c_linear_trainer<kernel_type> trainer;
    trainer.set_c(10);

    std::vector<sparse_vector_type> samples;
    std::vector<double> labels;

    // Stage 1: positive example on feature 0, negative example on feature 1.
    sparse_vector_type positive_sample;
    sparse_vector_type negative_sample;
    positive_sample[0] = 1;
    negative_sample[1] = 1;

    samples.push_back(positive_sample);
    labels.push_back(+1);
    samples.push_back(negative_sample);
    labels.push_back(-1);

    decision_function<kernel_type> df = trainer.train(samples, labels);

    // Use the stage 1 solution as the prior for the next training run.
    trainer.set_prior(df);

    // Stage 2: the positive example moves to feature 2; the negative example
    // is unchanged (still feature 1).
    samples.clear();
    labels.clear();
    positive_sample.clear();
    positive_sample[2] = 1;

    samples.push_back(positive_sample);
    labels.push_back(+1);
    samples.push_back(negative_sample);
    labels.push_back(-1);

    df = trainer.train(samples, labels);

    // Both entries of the evaluation result are expected to be exactly 1.
    matrix<double,1,2> rs = test_binary_decision_function(df, samples, labels);
    dlog << LINFO << rs;
    DLIB_TEST(rs(0) == 1);
    DLIB_TEST(rs(1) == 1);

    // Feature 0 does not occur in the stage 2 data, so a positive weight there
    // presumably comes from the prior; features 1 and 2 follow their labels.
    matrix<double,0,1> w = sparse_to_dense(df.basis_vectors(0));
    dlog << LINFO << trans(w);
    DLIB_TEST(w(0) > 0.1);
    DLIB_TEST(w(1) < -0.1);
    DLIB_TEST(w(2) > 0.1);
}
void get_simple_points ( void get_simple_points (
std::vector<sample_type>& samples, std::vector<sample_type>& samples,
std::vector<double>& labels std::vector<double>& labels
...@@ -255,6 +293,7 @@ namespace ...@@ -255,6 +293,7 @@ namespace
test_dense(); test_dense();
test_sparse(); test_sparse();
run_prior_test(); run_prior_test();
run_prior_sparse_test();
// test mixed sparse and dense dot products // test mixed sparse and dense dot products
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment