Commit 565e63f2 authored by Davis King's avatar Davis King

Added tests for the rls object

parent c3319182
...@@ -86,6 +86,7 @@ set (tests ...@@ -86,6 +86,7 @@ set (tests
rand.cpp rand.cpp
read_write_mutex.cpp read_write_mutex.cpp
reference_counter.cpp reference_counter.cpp
rls.cpp
scan_image.cpp scan_image.cpp
sequence.cpp sequence.cpp
sequence_labeler.cpp sequence_labeler.cpp
......
...@@ -101,6 +101,7 @@ SRC += queue.cpp ...@@ -101,6 +101,7 @@ SRC += queue.cpp
SRC += rand.cpp SRC += rand.cpp
SRC += read_write_mutex.cpp SRC += read_write_mutex.cpp
SRC += reference_counter.cpp SRC += reference_counter.cpp
SRC += rls.cpp
SRC += scan_image.cpp SRC += scan_image.cpp
SRC += sequence.cpp SRC += sequence.cpp
SRC += sequence_labeler.cpp SRC += sequence_labeler.cpp
......
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include <cmath>
#include <dlib/svm.h>
#include "tester.h"
namespace
{
using namespace test;
using namespace dlib;
using namespace std;
logger dlog("test.rls");
void test_rls()
{
dlib::rand rnd;
running_stats<double> rs1, rs2, rs3, rs4;
for (int k = 0; k < 2; ++k)
{
for (long num_vars = 1; num_vars < 4; ++num_vars)
{
print_spinner();
for (long size = 1; size < 300; ++size)
{
{
matrix<double> X = randm(size,num_vars,rnd);
matrix<double,0,1> Y = randm(size,1,rnd);
const double C = 1000;
const double forget_factor = 1.0;
rls r(forget_factor, C);
for (long i = 0; i < Y.size(); ++i)
{
r.train(trans(rowm(X,i)), Y(i));
}
matrix<double> w = pinv(1.0/C*identity_matrix<double>(X.nc()) + trans(X)*X)*trans(X)*Y;
rs1.add(length(r.get_w() - w));
}
{
matrix<double> X = randm(size,num_vars,rnd);
matrix<double> Y = colm(X,0)*10;
const double C = 1000000;
const double forget_factor = 1.0;
rls r(forget_factor, C);
for (long i = 0; i < Y.size(); ++i)
{
r.train(trans(rowm(X,i)), Y(i));
}
matrix<double> w = pinv(1.0/C*identity_matrix<double>(X.nc()) + trans(X)*X)*trans(X)*Y;
rs2.add(length(r.get_w() - w));
}
{
matrix<double> X = join_rows(randm(size,num_vars,rnd)-0.5, ones_matrix<double>(size,1));
matrix<double> Y = uniform_matrix<double>(size,1,10);
const double C = 1e7;
const double forget_factor = 1.0;
matrix<double> w = pinv(1.0/C*identity_matrix<double>(X.nc()) + trans(X)*X)*trans(X)*Y;
rls r(forget_factor, C);
for (long i = 0; i < Y.size(); ++i)
{
r.train(trans(rowm(X,i)), Y(i));
rs3.add(std::abs(r(trans(rowm(X,i))) - 10));
}
}
{
matrix<double> X = randm(size,num_vars,rnd)-0.5;
matrix<double> Y = colm(X,0)*10;
const double C = 1e6;
const double forget_factor = 0.7;
rls r(forget_factor, C);
DLIB_TEST(std::abs(r.get_c() - C) < 1e-10);
DLIB_TEST(std::abs(r.get_forget_factor() - forget_factor) < 1e-15);
DLIB_TEST(r.get_w().size() == 0);
for (long i = 0; i < Y.size(); ++i)
{
r.train(trans(rowm(X,i)), Y(i));
rs4.add(std::abs(r(trans(rowm(X,i))) - X(i,0)*10));
}
DLIB_TEST(r.get_w().size() == num_vars);
decision_function<linear_kernel<matrix<double,0,1> > > df = r.get_decision_function();
DLIB_TEST(std::abs(df(trans(rowm(X,0))) - r(trans(rowm(X,0)))) < 1e-15);
}
}
}
}
dlog << LINFO << "rs1.mean(): " << rs1.mean();
dlog << LINFO << "rs2.mean(): " << rs2.mean();
dlog << LINFO << "rs3.mean(): " << rs3.mean();
dlog << LINFO << "rs4.mean(): " << rs4.mean();
dlog << LINFO << "rs1.max(): " << rs1.max();
dlog << LINFO << "rs2.max(): " << rs2.max();
dlog << LINFO << "rs3.max(): " << rs3.max();
dlog << LINFO << "rs4.max(): " << rs4.max();
DLIB_TEST_MSG(rs1.mean() < 1e-10, rs1.mean());
DLIB_TEST_MSG(rs2.mean() < 1e-9, rs2.mean());
DLIB_TEST_MSG(rs3.mean() < 1e-6, rs3.mean());
DLIB_TEST_MSG(rs4.mean() < 1e-6, rs4.mean());
DLIB_TEST_MSG(rs1.max() < 1e-10, rs1.max());
DLIB_TEST_MSG(rs2.max() < 1e-6, rs2.max());
DLIB_TEST_MSG(rs3.max() < 0.001, rs3.max());
DLIB_TEST_MSG(rs4.max() < 0.01, rs4.max());
}
class rls_tester : public tester
{
public:
rls_tester (
) :
tester ("test_rls",
"Runs tests on the rls component.")
{}
void perform_test (
)
{
test_rls();
}
} a;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment