Commit 2f25dc55 authored by Davis King's avatar Davis King

Added a test specifically for the svm_c_linear_trainer and for the

oca solver by proxy.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403538
parent 372e6ec2
...@@ -79,6 +79,7 @@ set (tests ...@@ -79,6 +79,7 @@ set (tests
std_vector_c.cpp std_vector_c.cpp
string.cpp string.cpp
svm.cpp svm.cpp
svm_c_linear.cpp
thread_pool.cpp thread_pool.cpp
threads.cpp threads.cpp
timer.cpp timer.cpp
......
...@@ -89,6 +89,7 @@ SRC += statistics.cpp ...@@ -89,6 +89,7 @@ SRC += statistics.cpp
SRC += std_vector_c.cpp SRC += std_vector_c.cpp
SRC += string.cpp SRC += string.cpp
SRC += svm.cpp SRC += svm.cpp
SRC += svm_c_linear.cpp
SRC += thread_pool.cpp SRC += thread_pool.cpp
SRC += threads.cpp SRC += threads.cpp
SRC += timer.cpp SRC += timer.cpp
......
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <dlib/matrix.h>
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include <vector>
#include "../stl_checked.h"
#include "../array.h"
#include "../rand.h"
#include "checkerboard.h"
#include <dlib/statistics.h>
#include "tester.h"
#include <dlib/svm.h>
namespace
{
using namespace test;
using namespace dlib;
using namespace std;
logger dlog("test.svm_c_linear");
typedef matrix<double, 0, 1> sample_type;
typedef std::vector<std::pair<unsigned int, double> > sparse_sample_type;
// ----------------------------------------------------------------------------------------
void get_simple_points (
std::vector<sample_type>& samples,
std::vector<double>& labels
)
{
samples.clear();
labels.clear();
sample_type samp(2);
samp = 0,0;
samples.push_back(samp);
labels.push_back(-1);
samp = 0,1;
samples.push_back(samp);
labels.push_back(-1);
samp = 3,0;
samples.push_back(samp);
labels.push_back(+1);
samp = 3,1;
samples.push_back(samp);
labels.push_back(+1);
}
// ----------------------------------------------------------------------------------------
void get_simple_points_sparse (
std::vector<sparse_sample_type>& samples,
std::vector<double>& labels
)
{
samples.clear();
labels.clear();
sparse_sample_type samp;
samp.push_back(make_pair(0, 0.0));
samp.push_back(make_pair(1, 0.0));
samples.push_back(samp);
labels.push_back(-1);
samp.clear();
samp.push_back(make_pair(0, 0.0));
samp.push_back(make_pair(1, 1.0));
samples.push_back(samp);
labels.push_back(-1);
samp.clear();
samp.push_back(make_pair(0, 3.0));
samp.push_back(make_pair(1, 0.0));
samples.push_back(samp);
labels.push_back(+1);
samp.clear();
samp.push_back(make_pair(0, 3.0));
samp.push_back(make_pair(1, 1.0));
samples.push_back(samp);
labels.push_back(+1);
}
// ----------------------------------------------------------------------------------------
void test_sparse (
)
{
print_spinner();
dlog << LINFO << "test with sparse vectors";
std::vector<sparse_sample_type> samples;
std::vector<double> labels;
sample_type samp;
get_simple_points_sparse(samples,labels);
svm_c_linear_trainer<sparse_linear_kernel<sparse_sample_type> > trainer;
trainer.set_c(1e4);
//trainer.be_verbose();
trainer.set_epsilon(1e-8);
double obj;
decision_function<sparse_linear_kernel<sparse_sample_type> > df = trainer.train(samples, labels, obj);
dlog << LDEBUG << "obj: "<< obj;
DLIB_TEST_MSG(abs(obj - 0.72222222222) < 1e-8, obj);
DLIB_TEST(abs(df(samples[0]) - (-1)) < 1e-6);
DLIB_TEST(abs(df(samples[1]) - (-1)) < 1e-6);
DLIB_TEST(abs(df(samples[2]) - (1)) < 1e-6);
DLIB_TEST(abs(df(samples[3]) - (1)) < 1e-6);
}
// ----------------------------------------------------------------------------------------
void test_dense (
)
{
print_spinner();
dlog << LINFO << "test with dense vectors";
std::vector<sample_type> samples;
std::vector<double> labels;
sample_type samp;
get_simple_points(samples,labels);
svm_c_linear_trainer<linear_kernel<sample_type> > trainer;
trainer.set_c(1e4);
//trainer.be_verbose();
trainer.set_epsilon(1e-8);
double obj;
decision_function<linear_kernel<sample_type> > df = trainer.train(samples, labels, obj);
dlog << LDEBUG << "obj: "<< obj;
DLIB_TEST_MSG(abs(obj - 0.72222222222) < 1e-8, obj);
// There shouldn't be any margin violations since this dataset is so trivial. So that means the objective
// should be exactly the squared norm of the decision plane (times 0.5).
DLIB_TEST_MSG(abs(length_squared(df.basis_vectors(0))*0.5 + df.b*df.b*0.5 - 0.72222222222) < 1e-8,
length_squared(df.basis_vectors(0))*0.5 + df.b*df.b*0.5);
DLIB_TEST(abs(df(samples[0]) - (-1)) < 1e-6);
DLIB_TEST(abs(df(samples[1]) - (-1)) < 1e-6);
DLIB_TEST(abs(df(samples[2]) - (1)) < 1e-6);
DLIB_TEST(abs(df(samples[3]) - (1)) < 1e-6);
}
// ----------------------------------------------------------------------------------------
class svm_c_linear_tester : public tester
{
public:
svm_c_linear_tester (
) :
tester ("test_svm_c_linear",
"Runs tests on the svm_c_linear_trainer.")
{}
void perform_test (
)
{
test_dense();
test_sparse();
}
} a;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment