Commit 016f41ac authored by Davis King's avatar Davis King

Changed the oca interface to allow you to specify that a range of w

elements should be non-negative rather than just being able to say
all or none of them are non-negative.
parent ec2f30b6
...@@ -111,7 +111,7 @@ namespace dlib ...@@ -111,7 +111,7 @@ namespace dlib
typename matrix_type::type operator() ( typename matrix_type::type operator() (
const oca_problem<matrix_type>& problem, const oca_problem<matrix_type>& problem,
matrix_type& w, matrix_type& w,
bool require_nonnegative_w = false unsigned long num_nonnegative = 0
) const ) const
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -124,6 +124,9 @@ namespace dlib ...@@ -124,6 +124,9 @@ namespace dlib
<< "\n\t this: " << this << "\n\t this: " << this
); );
if (num_nonnegative > static_cast<unsigned long>(problem.get_num_dimensions()))
num_nonnegative = problem.get_num_dimensions();
typedef typename matrix_type::type scalar_type; typedef typename matrix_type::type scalar_type;
typedef typename matrix_type::layout_type layout_type; typedef typename matrix_type::layout_type layout_type;
typedef typename matrix_type::mem_manager_type mem_manager_type; typedef typename matrix_type::mem_manager_type mem_manager_type;
...@@ -218,15 +221,16 @@ namespace dlib ...@@ -218,15 +221,16 @@ namespace dlib
eps = 1e-16; eps = 1e-16;
// Note that we warm start this optimization by using the alpha from the last // Note that we warm start this optimization by using the alpha from the last
// iteration as the starting point. // iteration as the starting point.
if (require_nonnegative_w) if (num_nonnegative != 0)
solve_qp4_using_smo(planes, K, vector_to_matrix(bs), alpha, eps, sub_max_iter); solve_qp4_using_smo(rowm(planes,range(0,num_nonnegative-1)), K, vector_to_matrix(bs), alpha, eps, sub_max_iter);
else else
solve_qp_using_smo(K, vector_to_matrix(bs), alpha, eps, sub_max_iter); solve_qp_using_smo(K, vector_to_matrix(bs), alpha, eps, sub_max_iter);
// construct the w that minimized the subproblem. // construct the w that minimized the subproblem.
w = -(planes*alpha); w = -(planes*alpha);
if (require_nonnegative_w) // threshold the first num_nonnegative w elements if necessary.
w = lowerbound(w,0); if (num_nonnegative != 0)
set_rowm(w,range(0,num_nonnegative-1)) = lowerbound(rowm(w,range(0,num_nonnegative-1)),0);
for (long i = 0; i < alpha.size(); ++i) for (long i = 0; i < alpha.size(); ++i)
{ {
......
...@@ -151,7 +151,7 @@ namespace dlib ...@@ -151,7 +151,7 @@ namespace dlib
typename matrix_type::type operator() ( typename matrix_type::type operator() (
const oca_problem<matrix_type>& problem, const oca_problem<matrix_type>& problem,
matrix_type& w, matrix_type& w,
bool require_nonnegative_w = false unsigned long num_nonnegative = 0
) const; ) const;
/*! /*!
requires requires
...@@ -162,10 +162,10 @@ namespace dlib ...@@ -162,10 +162,10 @@ namespace dlib
- The optimization algorithm runs until problem.optimization_status() - The optimization algorithm runs until problem.optimization_status()
indicates it is time to stop. indicates it is time to stop.
- returns the objective value at the solution #w - returns the objective value at the solution #w
- if (require_nonnegative_w == true) then - if (num_nonnegative != 0) then
- Adds the constraint that every element of w be non-negative. - Adds the constraint that #w(i) >= 0 for all i < num_nonnegative.
Therefore, if this argument is true then #w won't contain any That is, the first num_nonnegative elements of #w will always be
negative values. non-negative.
!*/ !*/
void set_subproblem_epsilon ( void set_subproblem_epsilon (
......
...@@ -60,8 +60,8 @@ namespace ...@@ -60,8 +60,8 @@ namespace
oca solver; oca solver;
// test the version without a non-negativity constraint on w. // test the version without a non-negativity constraint on w.
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, false); solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, 0);
dlog << LINFO << w; dlog << LINFO << trans(w);
true_w = -0.5, 0.5, 0; true_w = -0.5, 0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w)); dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10); DLIB_TEST(max(abs(w-true_w)) < 1e-10);
...@@ -69,11 +69,69 @@ namespace ...@@ -69,11 +69,69 @@ namespace
print_spinner(); print_spinner();
// test the version with a non-negativity constraint on w. // test the version with a non-negativity constraint on w.
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, true); solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, 9999);
dlog << LINFO << w; dlog << LINFO << trans(w);
true_w = 0, 1, 0; true_w = 0, 1, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w)); dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10); DLIB_TEST(max(abs(w-true_w)) < 1e-10);
print_spinner();
// test the version with a non-negativity constraint on w.
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, 2);
dlog << LINFO << trans(w);
true_w = 0, 1, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);
print_spinner();
// test the version with a non-negativity constraint on w.
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, 1);
dlog << LINFO << trans(w);
true_w = 0, 1, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);
print_spinner();
// switching the labels should change which w weight goes negative.
y.clear();
y.push_back(-1);
y.push_back(+1);
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, 0);
dlog << LINFO << trans(w);
true_w = 0.5, -0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);
print_spinner();
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, 1);
dlog << LINFO << trans(w);
true_w = 0.5, -0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);
print_spinner();
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, 2);
dlog << LINFO << trans(w);
true_w = 1, 0, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);
print_spinner();
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, 5);
dlog << LINFO << trans(w);
true_w = 1, 0, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);
} }
} a; } a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment