Commit 0250dd00 authored by Davis King's avatar Davis King

Added unit tests for the assignment learning stuff

parent 44c79bcb
...@@ -14,6 +14,7 @@ set (tests ...@@ -14,6 +14,7 @@ set (tests
any_function.cpp any_function.cpp
array2d.cpp array2d.cpp
array.cpp array.cpp
assignment_learning.cpp
base64.cpp base64.cpp
bayes_nets.cpp bayes_nets.cpp
bigint.cpp bigint.cpp
......
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include "tester.h"
#include <dlib/svm_threaded.h>
#include <dlib/rand.h>
namespace
{
using namespace test;
using namespace dlib;
using namespace std;
logger dlog("test.assignment_learning");
// ----------------------------------------------------------------------------------------
typedef matrix<double,3,1> lhs_element;
typedef matrix<double,3,1> rhs_element;
// ----------------------------------------------------------------------------------------
struct feature_extractor_dense
{
typedef matrix<double,4,1> feature_vector_type;
typedef ::lhs_element lhs_element;
typedef ::rhs_element rhs_element;
unsigned long num_features() const
{
return 4;
}
void get_features (
const lhs_element& left,
const rhs_element& right,
feature_vector_type& feats
) const
{
feats = join_cols(squared(left - right), ones_matrix<double>(1,1));
}
};
void serialize (const feature_extractor_dense& , std::ostream& ) {}
void deserialize (feature_extractor_dense& , std::istream& ) {}
// ----------------------------------------------------------------------------------------
struct feature_extractor_sparse
{
typedef std::vector<std::pair<unsigned long,double> > feature_vector_type;
typedef ::lhs_element lhs_element;
typedef ::rhs_element rhs_element;
unsigned long num_features() const
{
return 4;
}
void get_features (
const lhs_element& left,
const rhs_element& right,
feature_vector_type& feats
) const
{
feats.clear();
feats.push_back(make_pair(0,squared(left-right)(0)));
feats.push_back(make_pair(1,squared(left-right)(1)));
feats.push_back(make_pair(2,squared(left-right)(2)));
feats.push_back(make_pair(3,1.0));
}
};
void serialize (const feature_extractor_sparse& , std::ostream& ) {}
void deserialize (feature_extractor_sparse& , std::istream& ) {}
// ----------------------------------------------------------------------------------------
typedef std::pair<std::vector<lhs_element>, std::vector<rhs_element> > sample_type;
typedef std::vector<long> label_type;
// ----------------------------------------------------------------------------------------
void make_data (
std::vector<sample_type>& samples,
std::vector<label_type>& labels
)
{
lhs_element a, b, c, d;
a = 1,0,0;
b = 0,1,0;
c = 0,0,1;
d = 0,1,1;
std::vector<lhs_element> lhs;
std::vector<rhs_element> rhs;
label_type label;
lhs.push_back(a);
lhs.push_back(b);
lhs.push_back(c);
rhs.push_back(b);
rhs.push_back(a);
rhs.push_back(c);
label.push_back(1);
label.push_back(0);
label.push_back(2);
samples.push_back(make_pair(lhs,rhs));
labels.push_back(label);
lhs.clear();
rhs.clear();
label.clear();
lhs.push_back(a);
lhs.push_back(b);
lhs.push_back(c);
rhs.push_back(c);
rhs.push_back(b);
rhs.push_back(a);
rhs.push_back(d);
label.push_back(2);
label.push_back(1);
label.push_back(0);
samples.push_back(make_pair(lhs,rhs));
labels.push_back(label);
lhs.clear();
rhs.clear();
label.clear();
lhs.push_back(a);
lhs.push_back(b);
lhs.push_back(c);
rhs.push_back(c);
rhs.push_back(a);
rhs.push_back(d);
label.push_back(1);
label.push_back(-1);
label.push_back(0);
samples.push_back(make_pair(lhs,rhs));
labels.push_back(label);
lhs.clear();
rhs.clear();
label.clear();
lhs.push_back(d);
lhs.push_back(b);
lhs.push_back(c);
label.push_back(-1);
label.push_back(-1);
label.push_back(-1);
samples.push_back(make_pair(lhs,rhs));
labels.push_back(label);
lhs.clear();
rhs.clear();
label.clear();
samples.push_back(make_pair(lhs,rhs));
labels.push_back(label);
}
// ----------------------------------------------------------------------------------------
void make_data_force (
std::vector<sample_type>& samples,
std::vector<label_type>& labels
)
{
lhs_element a, b, c, d;
a = 1,0,0;
b = 0,1,0;
c = 0,0,1;
d = 0,1,1;
std::vector<lhs_element> lhs;
std::vector<rhs_element> rhs;
label_type label;
lhs.push_back(a);
lhs.push_back(b);
lhs.push_back(c);
rhs.push_back(b);
rhs.push_back(a);
rhs.push_back(c);
label.push_back(1);
label.push_back(0);
label.push_back(2);
samples.push_back(make_pair(lhs,rhs));
labels.push_back(label);
lhs.clear();
rhs.clear();
label.clear();
lhs.push_back(a);
lhs.push_back(b);
lhs.push_back(c);
rhs.push_back(c);
rhs.push_back(b);
rhs.push_back(a);
rhs.push_back(d);
label.push_back(2);
label.push_back(1);
label.push_back(0);
samples.push_back(make_pair(lhs,rhs));
labels.push_back(label);
lhs.clear();
rhs.clear();
label.clear();
lhs.push_back(a);
lhs.push_back(c);
rhs.push_back(c);
rhs.push_back(a);
label.push_back(1);
label.push_back(0);
samples.push_back(make_pair(lhs,rhs));
labels.push_back(label);
lhs.clear();
rhs.clear();
label.clear();
samples.push_back(make_pair(lhs,rhs));
labels.push_back(label);
}
// ----------------------------------------------------------------------------------------
template <typename fe_type, typename F>
void test1(F make_data, bool force_assignment)
{
print_spinner();
std::vector<sample_type> samples;
std::vector<label_type> labels;
make_data(samples, labels);
make_data(samples, labels);
make_data(samples, labels);
randomize_samples(samples, labels);
structural_assignment_trainer<fe_type> trainer;
DLIB_TEST(trainer.forces_assignment() == false);
DLIB_TEST(trainer.get_c() == 100);
DLIB_TEST(trainer.get_num_threads() == 2);
DLIB_TEST(trainer.get_max_cache_size() == 40);
trainer.set_forces_assignment(force_assignment);
trainer.set_num_threads(3);
trainer.set_c(50);
DLIB_TEST(trainer.get_c() == 50);
DLIB_TEST(trainer.get_num_threads() == 3);
DLIB_TEST(trainer.forces_assignment() == force_assignment);
assignment_function<fe_type> ass = trainer.train(samples, labels);
for (unsigned long i = 0; i < samples.size(); ++i)
{
std::vector<long> out = ass(samples[i]);
dlog << LINFO << "true labels: " << trans(vector_to_matrix(labels[i]));
dlog << LINFO << "pred labels: " << trans(vector_to_matrix(out));
DLIB_TEST(trans(vector_to_matrix(labels[i])) == trans(vector_to_matrix(out)));
}
double accuracy;
dlog << LINFO << "samples.size(): "<< samples.size();
accuracy = test_assignment_function(ass, samples, labels);
dlog << LINFO << "accuracy: "<< accuracy;
DLIB_TEST(accuracy == 1);
accuracy = cross_validate_assignment_trainer(trainer, samples, labels, 3);
dlog << LINFO << "cv accuracy: "<< accuracy;
DLIB_TEST(accuracy == 1);
ostringstream sout;
serialize(ass, sout);
istringstream sin(sout.str());
assignment_function<fe_type> ass2;
deserialize(ass2, sin);
DLIB_TEST(ass2.forces_assignment() == ass.forces_assignment());
DLIB_TEST(length(ass2.get_weights() - ass.get_weights()) < 1e-10);
for (unsigned long i = 0; i < samples.size(); ++i)
{
std::vector<long> out = ass2(samples[i]);
dlog << LINFO << "true labels: " << trans(vector_to_matrix(labels[i]));
dlog << LINFO << "pred labels: " << trans(vector_to_matrix(out));
DLIB_TEST(trans(vector_to_matrix(labels[i])) == trans(vector_to_matrix(out)));
}
}
// ----------------------------------------------------------------------------------------
class test_assignment_learning : public tester
{
public:
test_assignment_learning (
) :
tester ("test_assignment_learning",
"Runs tests on the assignment learning code.")
{}
void perform_test (
)
{
test1<feature_extractor_dense>(make_data, false);
test1<feature_extractor_sparse>(make_data, false);
test1<feature_extractor_dense>(make_data_force, false);
test1<feature_extractor_sparse>(make_data_force, false);
test1<feature_extractor_dense>(make_data_force, true);
test1<feature_extractor_sparse>(make_data_force, true);
}
} a;
// ----------------------------------------------------------------------------------------
}
...@@ -29,6 +29,7 @@ SRC += any.cpp ...@@ -29,6 +29,7 @@ SRC += any.cpp
SRC += any_function.cpp SRC += any_function.cpp
SRC += array2d.cpp SRC += array2d.cpp
SRC += array.cpp SRC += array.cpp
SRC += assignment_learning.cpp
SRC += base64.cpp SRC += base64.cpp
SRC += bayes_nets.cpp SRC += bayes_nets.cpp
SRC += bigint.cpp SRC += bigint.cpp
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment