Commit 7e934b9d authored by Davis King

Added some tests for the LIBSVM formatted data IO functions and related routines.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404174
parent 4427c91a
......@@ -32,6 +32,8 @@ set (tests
conditioning_class_c.cpp
conditioning_class.cpp
config_reader.cpp
create_iris_datafile.cpp
data_io.cpp
directed_graph.cpp
discriminant_pca.cpp
ekm_and_lisf.cpp
......
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <sstream>
#include <fstream>
#include <dlib/compress_stream.h>
#include <dlib/base64.h>
namespace
{
    // Rebuilds and returns the text of the file 'iris.scale'.  The file is
    // embedded below as base64 text wrapping a dlib::compress_stream payload,
    // so it is first base64-decoded and then decompressed.
    const std::string get_decoded_string()
    {
        // The base64 encoded, compressed contents of 'iris.scale'.
        const std::string encoded_text =
            "MU66cCmT9lCXWJXwhdfOGELwlyExClbHEF1s9XoqxNDV7o8AdVVHws/C9oIKO5EShH1lI/QFTWk3"
            "8EUdVpSw/NpZCUa7O9nq5uO6SE0gfRAyryH+pfIVL9jPiQi8rBdagDf4kUd4eggz9glwYnKEE+US"
            "K4GUBnW33YDf/jMF2GIBLNvz69yGJj8RC5rOUeJxR4mlHrDmnEfgRSdFIfXk4OQ4V/XbOsE1bnhG"
            "fmACcu7nYv6M/043Z6o8oaBeoJ2XK/9UqOWGFOwfVpQ46fz1a0oTlOzyDbbzMiniLr8z5P/VYwYd"
            "iAE70MwxHXs6Ga3zMmD/h1WxB/uRRph39B1lPN1UXC7U6SIatmtGWY+JYpwBk6raAnR3sblTFBNs"
            "UdPW+1a7AxinR0NZO6YEiCFy8lbpfPRZNAr5ENqPbD2DZtkHk3L8ARxSoFBgqPa8aO3fFow7rVxF"
            "xIJ2TxcHS84+BtH7KvtWfH7kUPOZLQ+Ohqghn9I57IeMl7E3aoTRTiVv3P2twAbP5Y+ZaAUoU7CK"
            "c9FptjKgMClUkuWxA7tGUEp069PqGT8NbI+yxorh/iVhkVhuGAzgjjXYS/D26OGj4bzF6mtRbnms"
            "Y2OYlF7QqhawZaHLtmZ6xLhR2F8p/0nrbpAz2brQLNKgQAMvU9rTZ0XYpuJNbRSsARkRDorPopDO"
            "kKNUORfkh2zfIytVToQ9tZ9W2LkfGZdWjJu/wEKjPDAU55q3bCfKOUk12tjq0sq/7qjUWJRcLSCu"
            "bqo8EzaKJj3cTXVgXXLHP6WEOPZ9vShuxQUu1JWkh8YEinjwFSyA6UnAKqPtN/HsBgv8YbnfnY/q"
            "e5JvUYWbs3Lk9enlhcI0vEVTV5f0GMjdkW87l3cWgmXJqiljJDREWEdKZJQ0rGBU/gW5kO3SAS1W"
            "OETVJG2kJD8Ib7hT15Mu2lOVNQYFri6O3yWtp5/NLHsYXoDKIYrxoJtM9+GkprVwRuhDcwxE+eQa"
            "pp5nC8qj38ameQHaJR2hJCuW2nvr4Wwm0ploF00ZP9cS9YznCO52cueUQX0+zil7bU++jghqSGP5"
            "+JyRzWUWWbDhnCyanej2Y3sqfZ3o2kuUjaAgZFz5pLqK64uACjztp4bQFsaMRdc+OCV2uItqoaRg"
            "a6u7/VrvS+ZigwcGWDjXSKev334f8ZqQQIR5hljdeseGuw7/5XySzUrgc8lCOvMa0pKNn9Nl8W/W"
            "vbKz1VwA";

        // Stage 1: turn the base64 text back into the compressed binary data.
        std::istringstream base64_input(encoded_text);
        std::ostringstream compressed_output;
        dlib::base64::kernel_1a base64_coder;
        base64_coder.decode(base64_input, compressed_output);

        // Stage 2: decompress the binary data into the original file contents.
        std::istringstream compressed_input(compressed_output.str());
        std::ostringstream plain_output;
        dlib::compress_stream::kernel_1ea compressor;
        compressor.decompress(compressed_input, plain_output);

        return plain_output.str();
    }
}
namespace dlib
{
    // Writes the text returned by get_decoded_string() into a file named
    // "iris.scale" (a relative path, opened with the default text mode).
    //
    // Throws std::ios_base::failure if the file cannot be opened or if the
    // write fails.  Without these checks a failure here would go unnoticed and
    // only show up later as a confusing error in whatever tries to read the file.
    void create_iris_datafile (
    )
    {
        std::ofstream fout("iris.scale");
        if (!fout)
            throw std::ios_base::failure("Unable to open the file iris.scale for writing.");

        fout << get_decoded_string();

        // Flush explicitly so that buffered write errors are detected here
        // rather than being swallowed by the destructor.
        fout.flush();
        if (!fout)
            throw std::ios_base::failure("Error while writing to the file iris.scale.");
    }
}
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
// Note: the guard name deliberately avoids a double underscore, since
// identifiers containing "__" are reserved for the C++ implementation.
#ifndef DLIB_CREATE_IRIS_DAtAFILE_Hh_
#define DLIB_CREATE_IRIS_DAtAFILE_Hh_
namespace dlib
{
    void create_iris_datafile (
    );
    /*!
        ensures
            - Creates a local file called iris.scale that contains the
              150 samples from the 3-class Iris dataset from the UCI
              repository.  The file will be in LIBSVM format (it was
              originally downloaded from the LIBSVM website).
    !*/
}
#endif // DLIB_CREATE_IRIS_DAtAFILE_Hh_
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include "tester.h"
#include <dlib/svm.h>
#include <dlib/data_io.h>
#include "create_iris_datafile.h"
#include <vector>
#include <sstream>
namespace
{
using namespace test;
using namespace dlib;
using namespace std;
dlib::logger dlog("test.data_io");
class test_data_io : public tester
{
/*!
WHAT THIS OBJECT REPRESENTS
This object represents a unit test. When it is constructed
it adds itself into the testing framework.
!*/
public:
test_data_io (
) :
tester (
"test_data_io", // the command line argument name for this test
"Run tests on the data_io stuff.", // the command line argument description
0 // the number of command line arguments for this test
)
{
}
template <typename sample_type>
void run_test()
{
print_spinner();
typedef typename sample_type::value_type::second_type scalar_type;
std::vector<sample_type> samples;
std::vector<scalar_type> labels;
load_libsvm_formatted_data("iris.scale",samples, labels);
save_libsvm_formatted_data("iris.scale2", samples, labels);
DLIB_TEST(samples.size() == 150);
DLIB_TEST(labels.size() == 150);
DLIB_TEST(sparse_vector::max_index_plus_one(samples) == 5);
fix_nonzero_indexing(samples);
DLIB_TEST(sparse_vector::max_index_plus_one(samples) == 4);
load_libsvm_formatted_data("iris.scale2",samples, labels);
DLIB_TEST(samples.size() == 150);
DLIB_TEST(labels.size() == 150);
DLIB_TEST(sparse_vector::max_index_plus_one(samples) == 5);
fix_nonzero_indexing(samples);
DLIB_TEST(sparse_vector::max_index_plus_one(samples) == 4);
one_vs_one_trainer<any_trainer<sample_type,scalar_type>,scalar_type> trainer;
typedef sparse_linear_kernel<sample_type> kernel_type;
trainer.set_trainer(krr_trainer<kernel_type>());
randomize_samples(samples, labels);
matrix<scalar_type> cv = cross_validate_multiclass_trainer(trainer, samples, labels, 4);
dlog << LINFO << "confusion matrix: \n" << cv;
const scalar_type cv_accuracy = sum(diag(cv))/sum(cv);
dlog << LINFO << "cv accuracy: " << cv_accuracy;
DLIB_TEST(cv_accuracy > 0.97);
{
print_spinner();
typedef matrix<scalar_type,0,1> dsample_type;
std::vector<dsample_type> dsamples = sparse_to_dense(samples);
DLIB_TEST(dsamples.size() == 150);
DLIB_TEST(dsamples[0].size() == 4);
DLIB_TEST(sparse_vector::max_index_plus_one(dsamples) == 4);
one_vs_one_trainer<any_trainer<dsample_type,scalar_type>,scalar_type> trainer;
typedef linear_kernel<dsample_type> kernel_type;
trainer.set_trainer(rr_trainer<kernel_type>());
cv = cross_validate_multiclass_trainer(trainer, dsamples, labels, 4);
dlog << LINFO << "dense confusion matrix: \n" << cv;
const scalar_type cv_accuracy = sum(diag(cv))/sum(cv);
dlog << LINFO << "dense cv accuracy: " << cv_accuracy;
DLIB_TEST(cv_accuracy > 0.97);
}
}
void perform_test (
)
{
print_spinner();
create_iris_datafile();
run_test<std::map<unsigned int, double> >();
run_test<std::map<unsigned int, float> >();
run_test<std::vector<std::pair<unsigned int, float> > >();
run_test<std::vector<std::pair<unsigned long, double> > >();
}
};
test_data_io a;
}
......@@ -42,6 +42,8 @@ SRC += compress_stream.cpp
SRC += conditioning_class_c.cpp
SRC += conditioning_class.cpp
SRC += config_reader.cpp
SRC += create_iris_datafile.cpp
SRC += data_io.cpp
SRC += directed_graph.cpp
SRC += discriminant_pca.cpp
SRC += ekm_and_lisf.cpp
......@@ -100,8 +102,8 @@ SRC += static_set.cpp
SRC += statistics.cpp
SRC += std_vector_c.cpp
SRC += string.cpp
SRC += svm.cpp
SRC += svm_c_linear.cpp
SRC += svm.cpp
SRC += symmetric_matrix_cache.cpp
SRC += thread_pool.cpp
SRC += threads.cpp
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment