Commit 983f2cd3 authored by Davis King's avatar Davis King

Added unit tests for the new quadratic program solver.

parent c49f7d99
......@@ -72,6 +72,290 @@ namespace
matrix<double> Q, b;
};
// ----------------------------------------------------------------------------------------
double compute_objective_value (
const matrix<double,0,1>& w,
const matrix<double>& A,
const matrix<double,0,1>& b,
const double C
)
{
return 0.5*dot(w,w) + C*max(trans(A)*w + b);
}
// ----------------------------------------------------------------------------------------
void test_qp4_test1()
{
matrix<double> A(3,2);
A = 1,2,
-3,1,
6,7;
matrix<double,0,1> b(2);
b = 1,
2;
const double C = 2;
matrix<double,0,1> alpha(2), true_alpha(2);
alpha = C/2, C/2;
solve_qp4_using_smo(A, tmp(trans(A)*A), b, alpha, 1e-9, 800);
matrix<double,0,1> w = lowerbound(-A*alpha, 0);
dlog << LINFO << "*******************************************************";
dlog << LINFO << "w: " << trans(w);
dlog << LINFO << "computed obj: "<< compute_objective_value(w,A,b,C);
w = 0;
dlog << LINFO << "with true w obj: "<< compute_objective_value(w,A,b,C);
dlog << LINFO << "alpha: " << trans(alpha);
true_alpha = 0, 2;
dlog << LINFO << "true alpha: "<< trans(true_alpha);
dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha));
DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9);
}
// ----------------------------------------------------------------------------------------
void test_qp4_test2()
{
matrix<double> A(3,2);
A = 1,2,
3,-1,
6,7;
matrix<double,0,1> b(2);
b = 1,
2;
const double C = 2;
matrix<double,0,1> alpha(2), true_alpha(2);
alpha = C/2, C/2;
solve_qp4_using_smo(A, tmp(trans(A)*A), b, alpha, 1e-9, 800);
matrix<double,0,1> w = lowerbound(-A*alpha, 0);
dlog << LINFO << "*******************************************************";
dlog << LINFO << "w: " << trans(w);
dlog << LINFO << "computed obj: "<< compute_objective_value(w,A,b,C);
w = 0, 0.25, 0;
dlog << LINFO << "with true w obj: "<< compute_objective_value(w,A,b,C);
dlog << LINFO << "alpha: " << trans(alpha);
true_alpha = 0.43750, 1.56250;
dlog << LINFO << "true alpha: "<< trans(true_alpha);
dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha));
DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9);
}
// ----------------------------------------------------------------------------------------
void test_qp4_test3()
{
matrix<double> A(3,2);
A = 1,2,
-3,-1,
6,7;
matrix<double,0,1> b(2);
b = 1,
2;
const double C = 2;
matrix<double,0,1> alpha(2), true_alpha(2);
alpha = C/2, C/2;
solve_qp4_using_smo(A, tmp(trans(A)*A), b, alpha, 1e-9, 800);
matrix<double,0,1> w = lowerbound(-A*alpha, 0);
dlog << LINFO << "*******************************************************";
dlog << LINFO << "w: " << trans(w);
dlog << LINFO << "computed obj: "<< compute_objective_value(w,A,b,C);
w = 0, 2, 0;
dlog << LINFO << "with true w obj: "<< compute_objective_value(w,A,b,C);
dlog << LINFO << "alpha: " << trans(alpha);
true_alpha = 0, 2;
dlog << LINFO << "true alpha: "<< trans(true_alpha);
dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha));
DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9);
}
// ----------------------------------------------------------------------------------------
void test_qp4_test5()
{
matrix<double> A(3,3);
A = 1,2,4,
3,1,6,
6,7,-2;
matrix<double,0,1> b(3);
b = 1,
2,
3;
const double C = 2;
matrix<double,0,1> alpha(3), true_alpha(3);
alpha = C/2, C/2, 0;
solve_qp4_using_smo(A, tmp(trans(A)*A), b, alpha, 1e-9, 800);
matrix<double,0,1> w = lowerbound(-A*alpha, 0);
dlog << LINFO << "*******************************************************";
dlog << LINFO << "w: " << trans(w);
dlog << LINFO << "computed obj: "<< compute_objective_value(w,A,b,C);
w = 0, 0, 0.11111111111111111111;
dlog << LINFO << "with true w obj: "<< compute_objective_value(w,A,b,C);
dlog << LINFO << "alpha: " << trans(alpha);
true_alpha = 0, 0.432098765432099, 1.567901234567901;
dlog << LINFO << "true alpha: "<< trans(true_alpha);
dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha));
DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9);
}
// ----------------------------------------------------------------------------------------
void test_qp4_test4()
{
matrix<double> A(3,2);
A = 1,2,
3,1,
6,7;
matrix<double,0,1> b(2);
b = 1,
2;
const double C = 2;
matrix<double,0,1> alpha(2), true_alpha(2);
alpha = C/2, C/2;
solve_qp4_using_smo(A, tmp(trans(A)*A), b, alpha, 1e-9, 800);
matrix<double,0,1> w = lowerbound(-A*alpha, 0);
dlog << LINFO << "*******************************************************";
dlog << LINFO << "w: " << trans(w);
dlog << LINFO << "computed obj: "<< compute_objective_value(w,A,b,C);
w = 0, 0, 0;
dlog << LINFO << "with true w obj: "<< compute_objective_value(w,A,b,C);
dlog << LINFO << "alpha: " << trans(alpha);
true_alpha = 0, 2;
dlog << LINFO << "true alpha: "<< trans(true_alpha);
dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha));
DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9);
}
void test_qp4_test6()
{
matrix<double> A(3,3);
A = 1,2,4,
3,1,6,
6,7,-2;
matrix<double,0,1> b(3);
b = -1,
-2,
-3;
const double C = 2;
matrix<double,0,1> alpha(3), true_alpha(3);
alpha = C/2, C/2, 0;
solve_qp4_using_smo(A, tmp(trans(A)*A), b, alpha, 1e-9, 800);
matrix<double,0,1> w = lowerbound(-A*alpha, 0);
dlog << LINFO << "*******************************************************";
dlog << LINFO << "w: " << trans(w);
dlog << LINFO << "computed obj: "<< compute_objective_value(w,A,b,C);
w = 0, 0, 0;
dlog << LINFO << "with true w obj: "<< compute_objective_value(w,A,b,C);
dlog << LINFO << "alpha: " << trans(alpha);
true_alpha = 2, 0, 0;
dlog << LINFO << "true alpha: "<< trans(true_alpha);
dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha));
DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9);
}
void test_qp4_test7()
{
matrix<double> A(3,3);
A = -1,2,4,
-3,1,6,
-6,7,-2;
matrix<double,0,1> b(3);
b = -1,
-2,
3;
matrix<double> Q(3,3);
Q = 4,-5,6,
1,-4,2,
-9,-4,5;
Q = Q*trans(Q);
const double C = 2;
matrix<double,0,1> alpha(3), true_alpha(3);
alpha = C/2, C/2, 0;
solve_qp4_using_smo(A, Q, b, alpha, 1e-9, 800);
dlog << LINFO << "*******************************************************";
dlog << LINFO << "alpha: " << trans(alpha);
true_alpha = 0, 2, 0;
dlog << LINFO << "true alpha: "<< trans(true_alpha);
dlog << LINFO << "alpha error: "<< max(abs(alpha-true_alpha));
DLIB_TEST(max(abs(alpha-true_alpha)) < 1e-9);
}
// ----------------------------------------------------------------------------------------
void test_solve_qp4_using_smo()
{
test_qp4_test1();
test_qp4_test2();
test_qp4_test3();
test_qp4_test4();
test_qp4_test5();
test_qp4_test6();
test_qp4_test7();
}
// ----------------------------------------------------------------------------------------
class opt_qp_solver_tester : public tester
......@@ -98,6 +382,10 @@ namespace
void perform_test(
)
{
print_spinner();
test_solve_qp4_using_smo();
print_spinner();
++thetime;
typedef matrix<double,0,1> sample_type;
//dlog << LINFO << "time seed: " << thetime;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment