Commit 81aa8d72 authored by Davis King's avatar Davis King

Added a one vs. all multiclass trainer.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404070
parent c17214c3
......@@ -34,6 +34,8 @@
#include "svm/cross_validate_multiclass_trainer.h"
#include "svm/cross_validate_regression_trainer.h"
#include "svm/one_vs_all_decision_function.h"
#include "svm/one_vs_all_trainer.h"
#endif // DLIB_SVm_HEADER
......
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ONE_VS_ALL_DECISION_FUnCTION_H__
#define DLIB_ONE_VS_ALL_DECISION_FUnCTION_H__
#include "one_vs_all_decision_function_abstract.h"
#include "../serialize.h"
#include "../type_safe_union.h"
#include <sstream>
#include <map>
#include "../any.h"
#include "null_df.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename one_vs_all_trainer,
typename DF1 = null_df, typename DF2 = null_df, typename DF3 = null_df,
typename DF4 = null_df, typename DF5 = null_df, typename DF6 = null_df,
typename DF7 = null_df, typename DF8 = null_df, typename DF9 = null_df,
typename DF10 = null_df
>
class one_vs_all_decision_function
{
public:
typedef typename one_vs_all_trainer::label_type label_type;
typedef typename one_vs_all_trainer::sample_type sample_type;
typedef typename one_vs_all_trainer::scalar_type scalar_type;
typedef typename one_vs_all_trainer::mem_manager_type mem_manager_type;
typedef std::map<label_type, any_decision_function<sample_type, scalar_type> > binary_function_table;
one_vs_all_decision_function() :num_classes(0) {}
explicit one_vs_all_decision_function(
const binary_function_table& dfs_
) : dfs(dfs_)
{
num_classes = dfs.size();
}
const binary_function_table& get_binary_decision_functions (
) const
{
return dfs;
}
const std::vector<label_type> get_labels (
) const
{
std::vector<label_type> temp;
temp.reserve(dfs.size());
for (typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i)
{
temp.push_back(i->first);
}
return temp;
}
template <
typename df1, typename df2, typename df3, typename df4, typename df5,
typename df6, typename df7, typename df8, typename df9, typename df10
>
one_vs_all_decision_function (
const one_vs_all_decision_function<one_vs_all_trainer,
df1, df2, df3, df4, df5,
df6, df7, df8, df9, df10>& item
) : dfs(item.get_binary_decision_functions()), num_classes(item.number_of_classes()) {}
unsigned long number_of_classes (
) const
{
return num_classes;
}
label_type operator() (
const sample_type& sample
) const
{
DLIB_ASSERT(number_of_classes() != 0,
"\t void one_vs_all_decision_function::operator()"
<< "\n\t You can't make predictions with an empty decision function."
<< "\n\t this: " << this
);
label_type best_label = label_type();
scalar_type best_score = -std::numeric_limits<scalar_type>::infinity();
// run all the classifiers over the sample and find the best one
for(typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i)
{
const scalar_type score = i->second(sample);
if (score > best_score)
{
best_score = score;
best_label = i->first;
}
}
return best_label;
}
private:
binary_function_table dfs;
unsigned long num_classes;
};
// ----------------------------------------------------------------------------------------
template <
typename T,
typename DF1, typename DF2, typename DF3,
typename DF4, typename DF5, typename DF6,
typename DF7, typename DF8, typename DF9,
typename DF10
>
void serialize(
const one_vs_all_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item,
std::ostream& out
)
{
try
{
type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp;
typedef typename T::label_type label_type;
typedef typename T::sample_type sample_type;
typedef typename T::scalar_type scalar_type;
typedef std::map<label_type, any_decision_function<sample_type, scalar_type> > binary_function_table;
const unsigned long version = 1;
serialize(version, out);
const unsigned long size = item.get_binary_decision_functions().size();
serialize(size, out);
for(typename binary_function_table::const_iterator i = item.get_binary_decision_functions().begin();
i != item.get_binary_decision_functions().end(); ++i)
{
serialize(i->first, out);
if (i->second.template contains<DF1>()) temp.template get<DF1>() = any_cast<DF1>(i->second);
else if (i->second.template contains<DF2>()) temp.template get<DF2>() = any_cast<DF2>(i->second);
else if (i->second.template contains<DF3>()) temp.template get<DF3>() = any_cast<DF3>(i->second);
else if (i->second.template contains<DF4>()) temp.template get<DF4>() = any_cast<DF4>(i->second);
else if (i->second.template contains<DF5>()) temp.template get<DF5>() = any_cast<DF5>(i->second);
else if (i->second.template contains<DF6>()) temp.template get<DF6>() = any_cast<DF6>(i->second);
else if (i->second.template contains<DF7>()) temp.template get<DF7>() = any_cast<DF7>(i->second);
else if (i->second.template contains<DF8>()) temp.template get<DF8>() = any_cast<DF8>(i->second);
else if (i->second.template contains<DF9>()) temp.template get<DF9>() = any_cast<DF9>(i->second);
else if (i->second.template contains<DF10>()) temp.template get<DF10>() = any_cast<DF10>(i->second);
else throw serialization_error("Can't serialize one_vs_all_decision_function. Not all decision functions defined.");
serialize(temp,out);
}
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing an object of type one_vs_all_decision_function");
}
}
// ----------------------------------------------------------------------------------------
namespace impl_ova
{
template <typename sample_type, typename scalar_type>
struct copy_to_df_helper
{
copy_to_df_helper(any_decision_function<sample_type, scalar_type>& target_) : target(target_) {}
mutable any_decision_function<sample_type, scalar_type>& target;
template <typename T>
void operator() (
const T& item
) const
{
target = item;
}
};
}
template <
typename T,
typename DF1, typename DF2, typename DF3,
typename DF4, typename DF5, typename DF6,
typename DF7, typename DF8, typename DF9,
typename DF10
>
void deserialize(
one_vs_all_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item,
std::istream& in
)
{
try
{
type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp;
typedef typename T::label_type label_type;
typedef typename T::sample_type sample_type;
typedef typename T::scalar_type scalar_type;
typedef impl_ova::copy_to_df_helper<sample_type, scalar_type> copy_to;
unsigned long version;
deserialize(version, in);
if (version != 1)
throw serialization_error("Can't deserialize one_vs_all_decision_function. Wrong version.");
unsigned long size;
deserialize(size, in);
typedef std::map<label_type, any_decision_function<sample_type, scalar_type> > binary_function_table;
binary_function_table dfs;
label_type l;
for (unsigned long i = 0; i < size; ++i)
{
deserialize(l, in);
deserialize(temp, in);
if (temp.template contains<null_df>())
throw serialization_error("A sub decision function of unknown type was encountered.");
temp.apply_to_contents(copy_to(dfs[l]));
}
item = one_vs_all_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>(dfs);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while deserializing an object of type one_vs_all_decision_function");
}
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_ONE_VS_ALL_DECISION_FUnCTION_H__
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_ONE_VS_ALL_DECISION_FUnCTION_ABSTRACT_H__
#ifdef DLIB_ONE_VS_ALL_DECISION_FUnCTION_ABSTRACT_H__
#include "../serialize.h"
#include <map>
#include "../any/any_decision_function_abstract.h"
#include "one_vs_all_trainer_abstract.h"
#include "null_df.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename one_vs_all_trainer,
typename DF1 = null_df, typename DF2 = null_df, typename DF3 = null_df,
typename DF4 = null_df, typename DF5 = null_df, typename DF6 = null_df,
typename DF7 = null_df, typename DF8 = null_df, typename DF9 = null_df,
typename DF10 = null_df
>
class one_vs_all_decision_function
{
/*!
REQUIREMENTS ON one_vs_all_trainer
This should be an instantiation of the one_vs_all_trainer template.
It is used to infer which types are used for various things, such as
representing labels.
REQUIREMENTS ON DF*
These types can either be left at their default values or set
to any kind of decision function object capable of being
stored in an any_decision_function<sample_type,scalar_type>
object. These types should also be serializable.
WHAT THIS OBJECT REPRESENTS
This object represents a multiclass classifier built out of a set of
binary classifiers. Each binary classifier is used to vote for the
correct multiclass label using a one vs. all strategy. Therefore,
if you have N classes then there will be N binary classifiers inside
this object.
Note that the DF* template arguments are only used if you want
to serialize and deserialize one_vs_all_decision_function objects.
Specifically, all the types of binary decision function contained
within a one_vs_all_decision_function must be listed in the
template arguments if serialization and deserialization is to
be used.
!*/
public:
typedef typename one_vs_all_trainer::label_type label_type;
typedef typename one_vs_all_trainer::sample_type sample_type;
typedef typename one_vs_all_trainer::scalar_type scalar_type;
typedef typename one_vs_all_trainer::mem_manager_type mem_manager_type;
typedef std::map<label_type, any_decision_function<sample_type, scalar_type> > binary_function_table;
one_vs_all_decision_function(
);
/*!
ensures
- #number_of_classes() == 0
- #get_binary_decision_functions().size() == 0
- #get_labels().size() == 0
!*/
explicit one_vs_all_decision_function(
const binary_function_table& decision_functions
);
/*!
ensures
- #get_binary_decision_functions() == decision_functions
- #get_labels() == a list of all the labels which appear in the
given set of decision functions
- #number_of_classes() == #get_labels().size()
!*/
template <
typename df1, typename df2, typename df3, typename df4, typename df5,
typename df6, typename df7, typename df8, typename df9, typename df10
>
one_vs_all_decision_function (
const one_vs_all_decision_function<one_vs_all_trainer,
df1, df2, df3, df4, df5,
df6, df7, df8, df9, df10>& item
);
/*!
ensures
- #*this will be a copy of item
- #number_of_classes() == item.number_of_classes()
- #get_labels() == item.get_labels()
- #get_binary_decision_functions() == item.get_binary_decision_functions()
!*/
const binary_function_table& get_binary_decision_functions (
) const;
/*!
ensures
- returns the table of binary decision functions used by this
object. The label given to a test sample is computed by
determining which binary decision function has the largest
(i.e. most positive) output and returning the label associated
with that decision function.
!*/
const std::vector<label_type> get_labels (
) const;
/*!
ensures
- returns a vector containing all the labels which can be
predicted by this object.
!*/
unsigned long number_of_classes (
) const;
/*!
ensures
- returns get_labels().size()
(i.e. returns the number of different labels/classes predicted by
this object)
!*/
label_type operator() (
const sample_type& sample
) const
/*!
requires
- number_of_classes() != 0
ensures
- evaluates all the decision functions in get_binary_decision_functions()
and returns the predicted label.
!*/
};
// ----------------------------------------------------------------------------------------
template <
typename T,
typename DF1, typename DF2, typename DF3,
typename DF4, typename DF5, typename DF6,
typename DF7, typename DF8, typename DF9,
typename DF10
>
void serialize(
const one_vs_all_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item,
std::ostream& out
);
/*!
ensures
- writes the given item to the output stream out.
throws
- serialization_error.
This is thrown if there is a problem writing to the ostream or if item
contains a type of decision function not listed among the DF* template
arguments.
!*/
// ----------------------------------------------------------------------------------------
template <
typename T,
typename DF1, typename DF2, typename DF3,
typename DF4, typename DF5, typename DF6,
typename DF7, typename DF8, typename DF9,
typename DF10
>
void deserialize(
one_vs_all_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item,
std::istream& in
);
/*!
ensures
- deserializes a one_vs_all_decision_function from in and stores it in item.
throws
- serialization_error.
This is thrown if there is a problem reading from the istream or if the
serialized data contains decision functions not listed among the DF*
template arguments.
!*/
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_ONE_VS_ALL_DECISION_FUnCTION_ABSTRACT_H__
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ONE_VS_ALL_TRAiNER_H__
#define DLIB_ONE_VS_ALL_TRAiNER_H__
#include "one_vs_all_trainer_abstract.h"
#include "one_vs_all_decision_function.h"
#include <vector>
#include "multiclass_tools.h"
#include <sstream>
#include <iostream>
#include "../any.h"
#include <map>
#include <set>
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename any_trainer,
typename label_type_ = double
>
class one_vs_all_trainer
{
public:
typedef label_type_ label_type;
typedef typename any_trainer::sample_type sample_type;
typedef typename any_trainer::scalar_type scalar_type;
typedef typename any_trainer::mem_manager_type mem_manager_type;
typedef one_vs_all_decision_function<one_vs_all_trainer> trained_function_type;
one_vs_all_trainer (
) :
verbose(false)
{}
void set_trainer (
const any_trainer& trainer
)
{
default_trainer = trainer;
trainers.clear();
}
void set_trainer (
const any_trainer& trainer,
const label_type& l
)
{
trainers[l] = trainer;
}
void be_verbose (
)
{
verbose = true;
}
void be_quiet (
)
{
verbose = false;
}
struct invalid_label : public dlib::error
{
invalid_label(const std::string& msg, const label_type& l_
) : dlib::error(msg), l(l_) {};
virtual ~invalid_label(
) throw() {}
label_type l;
};
trained_function_type train (
const std::vector<sample_type>& all_samples,
const std::vector<label_type>& all_labels
) const
{
const std::vector<label_type> distinct_labels = select_all_distinct_labels(all_labels);
std::vector<scalar_type> labels;
typename trained_function_type::binary_function_table dfs;
for (unsigned long i = 0; i < distinct_labels.size(); ++i)
{
labels.clear();
const label_type l = distinct_labels[i];
// setup one of the one vs all training sets
for (unsigned long k = 0; k < all_samples.size(); ++k)
{
if (all_labels[k] == l)
labels.push_back(+1);
else
labels.push_back(-1);
}
if (verbose)
{
std::cout << "Training classifier for " << l << " vs. all" << std::endl;
}
// now train a binary classifier using the samples we selected
const typename binary_function_table::const_iterator itr = trainers.find(l);
if (itr != trainers.end())
{
dfs[l] = itr->second.train(all_samples, labels);
}
else if (default_trainer.is_empty() == false)
{
dfs[l] = default_trainer.train(all_samples, labels);
}
else
{
std::ostringstream sout;
sout << "In one_vs_all_trainer, no trainer registered for the " << l << " label.";
throw invalid_label(sout.str(), l);
}
}
return trained_function_type(dfs);
}
private:
any_trainer default_trainer;
typedef std::map<label_type, any_trainer> binary_function_table;
binary_function_table trainers;
bool verbose;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_ONE_VS_ALL_TRAiNER_H__
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_ONE_VS_ALL_TRAiNER_ABSTRACT_H__
#ifdef DLIB_ONE_VS_ALL_TRAiNER_ABSTRACT_H__
#include "one_vs_all_decision_function_abstract.h"
#include <vector>
#include "../any/any_trainer_abstract.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename any_trainer,
typename label_type_ = double
>
class one_vs_all_trainer
{
/*!
REQUIREMENTS ON any_trainer
must be an instantiation of the dlib::any_trainer template.
REQUIREMENTS ON label_type_
label_type_ must be default constructable, copyable, and comparable using
operator < and ==. It must also be possible to write it to an std::ostream
using operator<<.
WHAT THIS OBJECT REPRESENTS
This object is a tool for turning a bunch of binary classifiers into a
multiclass classifier. It does this by training the binary classifiers
in a one vs. all fashion. That is, if you have N possible classes then
it trains N binary classifiers which are then used to vote on the identity
of a test sample.
This object works with any kind of binary classification trainer object
capable of being assigned to an any_trainer object. (e.g. the svm_nu_trainer)
!*/
public:
typedef label_type_ label_type;
typedef typename any_trainer::sample_type sample_type;
typedef typename any_trainer::scalar_type scalar_type;
typedef typename any_trainer::mem_manager_type mem_manager_type;
typedef one_vs_all_decision_function<one_vs_all_trainer> trained_function_type;
one_vs_all_trainer (
);
/*!
ensures
- this object is properly initialized
- this object will not be verbose unless be_verbose() is called
- no binary trainers are associated with *this. I.e. you have to
call set_trainer() before calling train()
!*/
void set_trainer (
const any_trainer& trainer
);
/*!
ensures
- sets the trainer used for all binary subproblems. Any previous
calls to set_trainer() are overridden by this function. Even the
more specific set_trainer(trainer, l) form.
!*/
void set_trainer (
const any_trainer& trainer,
const label_type& l
);
/*!
ensures
- Sets the trainer object used to create a binary classifier to
distinguish l labeled samples from all other samples.
!*/
void be_verbose (
);
/*!
ensures
- This object will print status messages to standard out so that a
user can observe the progress of the algorithm.
!*/
void be_quiet (
);
/*!
ensures
- this object will not print anything to standard out
!*/
struct invalid_label : public dlib::error
{
/*!
This is the exception thrown by the train() function below.
!*/
label_type l;
};
trained_function_type train (
const std::vector<sample_type>& all_samples,
const std::vector<label_type>& all_labels
) const;
/*!
requires
- is_learning_problem(all_samples, all_labels)
ensures
- trains a bunch of binary classifiers in a one vs all fashion to solve the given
multiclass classification problem.
- returns a one_vs_all_decision_function F with the following properties:
- F contains all the learned binary classifiers and can be used to predict
the labels of new samples.
- if (new_x is a sample predicted to have a label of L) then
- F(new_x) == L
- F.get_labels() == select_all_distinct_labels(all_labels)
- F.number_of_classes() == select_all_distinct_labels(all_labels).size()
throws
- invalid_label
This exception is thrown if there are labels in all_labels which don't have
any corresponding trainer object. This will never happen if set_trainer(trainer)
has been called. However, if only the set_trainer(trainer,l) form has been
used then this exception is thrown if not all labels have been given a trainer.
invalid_label::l will contain the label which is missing a trainer object.
Additionally, the exception will contain an informative error message available
via invalid_label::what().
!*/
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_ONE_VS_ALL_TRAiNER_ABSTRACT_H__
......@@ -63,6 +63,7 @@ set (tests
metaprogramming.cpp
multithreaded_object.cpp
one_vs_one_trainer.cpp
one_vs_all_trainer.cpp
optimization.cpp
optimization_test_functions.cpp
opt_qp_solver.cpp
......
......@@ -72,6 +72,8 @@ SRC += md5.cpp
SRC += member_function_pointer.cpp
SRC += metaprogramming.cpp
SRC += multithreaded_object.cpp
SRC += one_vs_all_trainer.cpp
SRC += one_vs_one_trainer.cpp
SRC += optimization.cpp
SRC += optimization_test_functions.cpp
SRC += opt_qp_solver.cpp
......
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include "tester.h"
#include <dlib/svm.h>
#include <vector>
#include <sstream>
namespace
{
using namespace test;
using namespace dlib;
using namespace std;
dlib::logger dlog("test.one_vs_all_trainer");
class test_one_vs_all_trainer : public tester
{
/*!
WHAT THIS OBJECT REPRESENTS
This object represents a unit test. When it is constructed
it adds itself into the testing framework.
!*/
public:
test_one_vs_all_trainer (
) :
tester (
"test_one_vs_all_trainer", // the command line argument name for this test
"Run tests on the one_vs_all_trainer stuff.", // the command line argument description
0 // the number of command line arguments for this test
)
{
}
template <typename sample_type, typename label_type>
void generate_data (
std::vector<sample_type>& samples,
std::vector<label_type>& labels
)
{
const long num = 50;
sample_type m;
dlib::rand::float_1a rnd;
// make some samples near the origin
double radius = 0.5;
for (long i = 0; i < num+10; ++i)
{
double sign = 1;
if (rnd.get_random_double() < 0.5)
sign = -1;
m(0) = 2*radius*rnd.get_random_double()-radius;
m(1) = sign*sqrt(radius*radius - m(0)*m(0));
// add this sample to our set of samples we will run k-means
samples.push_back(m);
labels.push_back(1);
}
// make some samples in a circle around the origin but far away
radius = 10.0;
for (long i = 0; i < num+20; ++i)
{
double sign = 1;
if (rnd.get_random_double() < 0.5)
sign = -1;
m(0) = 2*radius*rnd.get_random_double()-radius;
m(1) = sign*sqrt(radius*radius - m(0)*m(0));
// add this sample to our set of samples we will run k-means
samples.push_back(m);
labels.push_back(2);
}
// make some samples in a circle around the point (25,25)
radius = 4.0;
for (long i = 0; i < num+30; ++i)
{
double sign = 1;
if (rnd.get_random_double() < 0.5)
sign = -1;
m(0) = 2*radius*rnd.get_random_double()-radius;
m(1) = sign*sqrt(radius*radius - m(0)*m(0));
// translate this point away from the origin
m(0) += 25;
m(1) += 25;
// add this sample to our set of samples we will run k-means
samples.push_back(m);
labels.push_back(3);
}
}
template <typename label_type, typename scalar_type>
void run_test (
)
{
print_spinner();
typedef matrix<scalar_type,2,1> sample_type;
std::vector<sample_type> samples;
std::vector<label_type> labels;
// First, get our labeled set of training data
generate_data(samples, labels);
typedef one_vs_all_trainer<any_trainer<sample_type,scalar_type>,label_type > ova_trainer;
ova_trainer trainer;
typedef polynomial_kernel<sample_type> poly_kernel;
typedef radial_basis_kernel<sample_type> rbf_kernel;
// make the binary trainers and set some parameters
krr_trainer<rbf_kernel> rbf_trainer;
svm_nu_trainer<poly_kernel> poly_trainer;
poly_trainer.set_kernel(poly_kernel(0.1, 1, 2));
rbf_trainer.set_kernel(rbf_kernel(0.1));
trainer.set_trainer(rbf_trainer);
trainer.set_trainer(poly_trainer, 1);
randomize_samples(samples, labels);
matrix<scalar_type> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2);
print_spinner();
matrix<scalar_type> ans(3,3);
ans = 60, 0, 0,
0, 70, 0,
0, 0, 80;
DLIB_TEST_MSG(ans == res, "res: \n" << res);
one_vs_all_decision_function<ova_trainer> df = trainer.train(samples, labels);
DLIB_TEST(df.number_of_classes() == 3);
DLIB_TEST(df(samples[0]) == labels[0])
DLIB_TEST(df(samples[90]) == labels[90])
one_vs_all_decision_function<ova_trainer,
decision_function<poly_kernel>, // This is the output of the poly_trainer
decision_function<rbf_kernel> // This is the output of the rbf_trainer
> df2, df3;
df2 = df;
ofstream fout("df.dat", ios::binary);
serialize(df2, fout);
fout.close();
// load the function back in from disk and store it in df3.
ifstream fin("df.dat", ios::binary);
deserialize(df3, fin);
DLIB_TEST(df3(samples[0]) == labels[0])
DLIB_TEST(df3(samples[90]) == labels[90])
res = test_multiclass_decision_function(df3, samples, labels);
DLIB_TEST(res == ans);
}
void perform_test (
)
{
dlog << LINFO << "run_test<double,double>()";
run_test<double,double>();
dlog << LINFO << "run_test<int,double>()";
run_test<int,double>();
dlog << LINFO << "run_test<double,float>()";
run_test<double,float>();
dlog << LINFO << "run_test<int,float>()";
run_test<int,float>();
}
};
test_one_vs_all_trainer a;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment