Commit 016f41ac authored by Davis King's avatar Davis King

Changed the oca interface to allow you to specify that a range of w

elements should be non-negative rather than just being able to say
all or none of them are non-negative.
parent ec2f30b6
......@@ -111,7 +111,7 @@ namespace dlib
typename matrix_type::type operator() (
const oca_problem<matrix_type>& problem,
matrix_type& w,
bool require_nonnegative_w = false
unsigned long num_nonnegative = 0
) const
{
// make sure requires clause is not broken
......@@ -124,6 +124,9 @@ namespace dlib
<< "\n\t this: " << this
);
if (num_nonnegative > static_cast<unsigned long>(problem.get_num_dimensions()))
num_nonnegative = problem.get_num_dimensions();
typedef typename matrix_type::type scalar_type;
typedef typename matrix_type::layout_type layout_type;
typedef typename matrix_type::mem_manager_type mem_manager_type;
......@@ -218,15 +221,16 @@ namespace dlib
eps = 1e-16;
// Note that we warm start this optimization by using the alpha from the last
// iteration as the starting point.
if (require_nonnegative_w)
solve_qp4_using_smo(planes, K, vector_to_matrix(bs), alpha, eps, sub_max_iter);
if (num_nonnegative != 0)
solve_qp4_using_smo(rowm(planes,range(0,num_nonnegative-1)), K, vector_to_matrix(bs), alpha, eps, sub_max_iter);
else
solve_qp_using_smo(K, vector_to_matrix(bs), alpha, eps, sub_max_iter);
// construct the w that minimized the subproblem.
w = -(planes*alpha);
if (require_nonnegative_w)
w = lowerbound(w,0);
// threshold the first num_nonnegative w elements if necessary.
if (num_nonnegative != 0)
set_rowm(w,range(0,num_nonnegative-1)) = lowerbound(rowm(w,range(0,num_nonnegative-1)),0);
for (long i = 0; i < alpha.size(); ++i)
{
......
......@@ -151,7 +151,7 @@ namespace dlib
typename matrix_type::type operator() (
const oca_problem<matrix_type>& problem,
matrix_type& w,
bool require_nonnegative_w = false
unsigned long num_nonnegative = 0
) const;
/*!
requires
......@@ -162,10 +162,10 @@ namespace dlib
- The optimization algorithm runs until problem.optimization_status()
indicates it is time to stop.
- returns the objective value at the solution #w
- if (require_nonnegative_w == true) then
- Adds the constraint that every element of w be non-negative.
Therefore, if this argument is true then #w won't contain any
negative values.
- if (num_nonnegative != 0) then
- Adds the constraint that #w(i) >= 0 for all i < num_nonnegative.
That is, the first num_nonnegative elements of #w will always be
non-negative.
!*/
void set_subproblem_epsilon (
......
......@@ -60,8 +60,8 @@ namespace
oca solver;
// test the version without a non-negativity constraint on w.
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, false);
dlog << LINFO << w;
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, 0);
dlog << LINFO << trans(w);
true_w = -0.5, 0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);
......@@ -69,11 +69,69 @@ namespace
print_spinner();
// test the version with a non-negativity constraint on w.
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, true);
dlog << LINFO << w;
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, 9999);
dlog << LINFO << trans(w);
true_w = 0, 1, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);
print_spinner();
// test the version with a non-negativity constraint on w.
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, 2);
dlog << LINFO << trans(w);
true_w = 0, 1, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);
print_spinner();
// test the version with a non-negativity constraint on w.
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, 1);
dlog << LINFO << trans(w);
true_w = 0, 1, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);
print_spinner();
// switching the labels should change which w weight goes negative.
y.clear();
y.push_back(-1);
y.push_back(+1);
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, 0);
dlog << LINFO << trans(w);
true_w = 0.5, -0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);
print_spinner();
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, 1);
dlog << LINFO << trans(w);
true_w = 0.5, -0.5, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);
print_spinner();
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, 2);
dlog << LINFO << trans(w);
true_w = 1, 0, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);
print_spinner();
solver(make_oca_problem_c_svm<w_type>(2.0, 3.0, vector_to_matrix(x), vector_to_matrix(y), false, 1e-12, 40), w, 5);
dlog << LINFO << trans(w);
true_w = 1, 0, 0;
dlog << LINFO << "error: "<< max(abs(w-true_w));
DLIB_TEST(max(abs(w-true_w)) < 1e-10);
}
} a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment