Commit 81aa8d72 authored by Davis King's avatar Davis King

Added a one vs. all multiclass trainer.

extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404070
parent c17214c3
......@@ -34,6 +34,8 @@
#include "svm/cross_validate_multiclass_trainer.h"
#include "svm/cross_validate_regression_trainer.h"
#include "svm/one_vs_all_decision_function.h"
#include "svm/one_vs_all_trainer.h"
#endif // DLIB_SVm_HEADER
// Copyright (C) 2010 Davis E. King
// License: Boost Software License. See LICENSE.txt for the full license.
#include "one_vs_all_decision_function_abstract.h"
#include "../serialize.h"
#include "../type_safe_union.h"
#include <sstream>
#include <map>
#include "../any.h"
#include "null_df.h"
namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename one_vs_all_trainer,
typename DF1 = null_df, typename DF2 = null_df, typename DF3 = null_df,
typename DF4 = null_df, typename DF5 = null_df, typename DF6 = null_df,
typename DF7 = null_df, typename DF8 = null_df, typename DF9 = null_df,
typename DF10 = null_df
class one_vs_all_decision_function
typedef typename one_vs_all_trainer::label_type label_type;
typedef typename one_vs_all_trainer::sample_type sample_type;
typedef typename one_vs_all_trainer::scalar_type scalar_type;
typedef typename one_vs_all_trainer::mem_manager_type mem_manager_type;
typedef std::map<label_type, any_decision_function<sample_type, scalar_type> > binary_function_table;
one_vs_all_decision_function() :num_classes(0) {}
explicit one_vs_all_decision_function(
const binary_function_table& dfs_
) : dfs(dfs_)
num_classes = dfs.size();
const binary_function_table& get_binary_decision_functions (
) const
return dfs;
const std::vector<label_type> get_labels (
) const
std::vector<label_type> temp;
for (typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i)
return temp;
template <
typename df1, typename df2, typename df3, typename df4, typename df5,
typename df6, typename df7, typename df8, typename df9, typename df10
one_vs_all_decision_function (
const one_vs_all_decision_function<one_vs_all_trainer,
df1, df2, df3, df4, df5,
df6, df7, df8, df9, df10>& item
) : dfs(item.get_binary_decision_functions()), num_classes(item.number_of_classes()) {}
unsigned long number_of_classes (
) const
return num_classes;
label_type operator() (
const sample_type& sample
) const
DLIB_ASSERT(number_of_classes() != 0,
"\t void one_vs_all_decision_function::operator()"
<< "\n\t You can't make predictions with an empty decision function."
<< "\n\t this: " << this
label_type best_label = label_type();
scalar_type best_score = -std::numeric_limits<scalar_type>::infinity();
// run all the classifiers over the sample and find the best one
for(typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i)
const scalar_type score = i->second(sample);
if (score > best_score)
best_score = score;
best_label = i->first;
return best_label;
binary_function_table dfs;
unsigned long num_classes;
// ----------------------------------------------------------------------------------------
template <
typename T,
typename DF1, typename DF2, typename DF3,
typename DF4, typename DF5, typename DF6,
typename DF7, typename DF8, typename DF9,
typename DF10
void serialize(
const one_vs_all_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item,
std::ostream& out
type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp;
typedef typename T::label_type label_type;
typedef typename T::sample_type sample_type;
typedef typename T::scalar_type scalar_type;
typedef std::map<label_type, any_decision_function<sample_type, scalar_type> > binary_function_table;
const unsigned long version = 1;
serialize(version, out);
const unsigned long size = item.get_binary_decision_functions().size();
serialize(size, out);
for(typename binary_function_table::const_iterator i = item.get_binary_decision_functions().begin();
i != item.get_binary_decision_functions().end(); ++i)
serialize(i->first, out);
if (i->second.template contains<DF1>()) temp.template get<DF1>() = any_cast<DF1>(i->second);
else if (i->second.template contains<DF2>()) temp.template get<DF2>() = any_cast<DF2>(i->second);
else if (i->second.template contains<DF3>()) temp.template get<DF3>() = any_cast<DF3>(i->second);
else if (i->second.template contains<DF4>()) temp.template get<DF4>() = any_cast<DF4>(i->second);
else if (i->second.template contains<DF5>()) temp.template get<DF5>() = any_cast<DF5>(i->second);
else if (i->second.template contains<DF6>()) temp.template get<DF6>() = any_cast<DF6>(i->second);
else if (i->second.template contains<DF7>()) temp.template get<DF7>() = any_cast<DF7>(i->second);
else if (i->second.template contains<DF8>()) temp.template get<DF8>() = any_cast<DF8>(i->second);
else if (i->second.template contains<DF9>()) temp.template get<DF9>() = any_cast<DF9>(i->second);
else if (i->second.template contains<DF10>()) temp.template get<DF10>() = any_cast<DF10>(i->second);
else throw serialization_error("Can't serialize one_vs_all_decision_function. Not all decision functions defined.");
catch (serialization_error& e)
throw serialization_error( + "\n while serializing an object of type one_vs_all_decision_function");
// ----------------------------------------------------------------------------------------
namespace impl_ova
template <typename sample_type, typename scalar_type>
struct copy_to_df_helper
copy_to_df_helper(any_decision_function<sample_type, scalar_type>& target_) : target(target_) {}
mutable any_decision_function<sample_type, scalar_type>& target;
template <typename T>
void operator() (
const T& item
) const
target = item;
template <
typename T,
typename DF1, typename DF2, typename DF3,
typename DF4, typename DF5, typename DF6,
typename DF7, typename DF8, typename DF9,
typename DF10
void deserialize(
one_vs_all_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item,
std::istream& in
type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp;
typedef typename T::label_type label_type;
typedef typename T::sample_type sample_type;
typedef typename T::scalar_type scalar_type;
typedef impl_ova::copy_to_df_helper<sample_type, scalar_type> copy_to;
unsigned long version;
deserialize(version, in);
if (version != 1)
throw serialization_error("Can't deserialize one_vs_all_decision_function. Wrong version.");
unsigned long size;
deserialize(size, in);
typedef std::map<label_type, any_decision_function<sample_type, scalar_type> > binary_function_table;
binary_function_table dfs;
label_type l;
for (unsigned long i = 0; i < size; ++i)
deserialize(l, in);
deserialize(temp, in);
if (temp.template contains<null_df>())
throw serialization_error("A sub decision function of unknown type was encountered.");
item = one_vs_all_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>(dfs);
catch (serialization_error& e)
throw serialization_error( + "\n while deserializing an object of type one_vs_all_decision_function");
// ----------------------------------------------------------------------------------------
// Copyright (C) 2010 Davis E. King
// License: Boost Software License. See LICENSE.txt for the full license.
#include "../serialize.h"
#include <map>
#include "../any/any_decision_function_abstract.h"
#include "one_vs_all_trainer_abstract.h"
#include "null_df.h"
namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename one_vs_all_trainer,
typename DF1 = null_df, typename DF2 = null_df, typename DF3 = null_df,
typename DF4 = null_df, typename DF5 = null_df, typename DF6 = null_df,
typename DF7 = null_df, typename DF8 = null_df, typename DF9 = null_df,
typename DF10 = null_df
class one_vs_all_decision_function
REQUIREMENTS ON one_vs_all_trainer
This should be an instantiation of the one_vs_all_trainer template.
It is used to infer which types are used for various things, such as
representing labels.
The DF* types can either be left at their default values or set
to any kind of decision function object capable of being
stored in an any_decision_function<sample_type,scalar_type>
object. These types should also be serializable.
This object represents a multiclass classifier built out of a set of
binary classifiers. Each binary classifier is used to vote for the
correct multiclass label using a one vs. all strategy. Therefore,
if you have N classes then there will be N binary classifiers inside
this object.
Note that the DF* template arguments are only used if you want
to serialize and deserialize one_vs_all_decision_function objects.
Specifically, all the types of binary decision function contained
within a one_vs_all_decision_function must be listed in the
template arguments if serialization and deserialization is to
be used.
typedef typename one_vs_all_trainer::label_type label_type;
typedef typename one_vs_all_trainer::sample_type sample_type;
typedef typename one_vs_all_trainer::scalar_type scalar_type;
typedef typename one_vs_all_trainer::mem_manager_type mem_manager_type;
typedef std::map<label_type, any_decision_function<sample_type, scalar_type> > binary_function_table;
- #number_of_classes() == 0
- #get_binary_decision_functions().size() == 0
- #get_labels().size() == 0
explicit one_vs_all_decision_function(
const binary_function_table& decision_functions
- #get_binary_decision_functions() == decision_functions
- #get_labels() == a list of all the labels which appear in the
given set of decision functions
- #number_of_classes() == #get_labels().size()
template <
typename df1, typename df2, typename df3, typename df4, typename df5,
typename df6, typename df7, typename df8, typename df9, typename df10
one_vs_all_decision_function (
const one_vs_all_decision_function<one_vs_all_trainer,
df1, df2, df3, df4, df5,
df6, df7, df8, df9, df10>& item
- #*this will be a copy of item
- #number_of_classes() == item.number_of_classes()
- #get_labels() == item.get_labels()
- #get_binary_decision_functions() == item.get_binary_decision_functions()
const binary_function_table& get_binary_decision_functions (
) const;
- returns the table of binary decision functions used by this
object. The label given to a test sample is computed by
determining which binary decision function has the largest
(i.e. most positive) output and returning the label associated
with that decision function.
const std::vector<label_type> get_labels (
) const;
- returns a vector containing all the labels which can be
predicted by this object.
unsigned long number_of_classes (
) const;
- returns get_labels().size()
(i.e. returns the number of different labels/classes predicted by
this object)
label_type operator() (
const sample_type& sample
) const
- number_of_classes() != 0
- evaluates all the decision functions in get_binary_decision_functions()
and returns the predicted label.
// ----------------------------------------------------------------------------------------
template <
typename T,
typename DF1, typename DF2, typename DF3,
typename DF4, typename DF5, typename DF6,
typename DF7, typename DF8, typename DF9,
typename DF10
void serialize(
const one_vs_all_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item,
std::ostream& out
- writes the given item to the output stream out.
- serialization_error.
This is thrown if there is a problem writing to the ostream or if item
contains a type of decision function not listed among the DF* template
arguments.
// ----------------------------------------------------------------------------------------
template <
typename T,
typename DF1, typename DF2, typename DF3,
typename DF4, typename DF5, typename DF6,
typename DF7, typename DF8, typename DF9,
typename DF10
void deserialize(
one_vs_all_decision_function<T,DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10>& item,
std::istream& in
- deserializes a one_vs_all_decision_function from in and stores it in item.
- serialization_error.
This is thrown if there is a problem reading from the istream or if the
serialized data contains decision functions not listed among the DF*
template arguments.
// ----------------------------------------------------------------------------------------
// Copyright (C) 2010 Davis E. King
// License: Boost Software License. See LICENSE.txt for the full license.
#include "one_vs_all_trainer_abstract.h"
#include "one_vs_all_decision_function.h"
#include <vector>
#include "multiclass_tools.h"
#include <sstream>
#include <iostream>
#include "../any.h"
#include <map>
#include <set>
namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename any_trainer,
typename label_type_ = double
class one_vs_all_trainer
typedef label_type_ label_type;
typedef typename any_trainer::sample_type sample_type;
typedef typename any_trainer::scalar_type scalar_type;
typedef typename any_trainer::mem_manager_type mem_manager_type;
typedef one_vs_all_decision_function<one_vs_all_trainer> trained_function_type;
one_vs_all_trainer (
) :
void set_trainer (
const any_trainer& trainer
default_trainer = trainer;
void set_trainer (
const any_trainer& trainer,
const label_type& l
trainers[l] = trainer;
void be_verbose (
verbose = true;
void be_quiet (
verbose = false;
struct invalid_label : public dlib::error
invalid_label(const std::string& msg, const label_type& l_
) : dlib::error(msg), l(l_) {};
virtual ~invalid_label(
) throw() {}
label_type l;
trained_function_type train (
const std::vector<sample_type>& all_samples,
const std::vector<label_type>& all_labels
) const
const std::vector<label_type> distinct_labels = select_all_distinct_labels(all_labels);
std::vector<scalar_type> labels;
typename trained_function_type::binary_function_table dfs;
for (unsigned long i = 0; i < distinct_labels.size(); ++i)
const label_type l = distinct_labels[i];
// setup one of the one vs all training sets
for (unsigned long k = 0; k < all_samples.size(); ++k)
if (all_labels[k] == l)
if (verbose)
std::cout << "Training classifier for " << l << " vs. all" << std::endl;
// now train a binary classifier using the samples we selected
const typename binary_function_table::const_iterator itr = trainers.find(l);
if (itr != trainers.end())
dfs[l] = itr->second.train(all_samples, labels);
else if (default_trainer.is_empty() == false)
dfs[l] = default_trainer.train(all_samples, labels);
std::ostringstream sout;
sout << "In one_vs_all_trainer, no trainer registered for the " << l << " label.";
throw invalid_label(sout.str(), l);
return trained_function_type(dfs);
any_trainer default_trainer;
typedef std::map<label_type, any_trainer> binary_function_table;
binary_function_table trainers;
bool verbose;
// ----------------------------------------------------------------------------------------
// Copyright (C) 2010 Davis E. King
// License: Boost Software License. See LICENSE.txt for the full license.
#include "one_vs_all_decision_function_abstract.h"
#include <vector>
#include "../any/any_trainer_abstract.h"
namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename any_trainer,
typename label_type_ = double
class one_vs_all_trainer
any_trainer must be an instantiation of the dlib::any_trainer template.
label_type_ must be default constructible, copyable, and comparable using
operator < and ==. It must also be possible to write it to an std::ostream
using operator<<.
This object is a tool for turning a bunch of binary classifiers into a
multiclass classifier. It does this by training the binary classifiers
in a one vs. all fashion. That is, if you have N possible classes then
it trains N binary classifiers which are then used to vote on the identity
of a test sample.
This object works with any kind of binary classification trainer object
capable of being assigned to an any_trainer object. (e.g. the svm_nu_trainer)
typedef label_type_ label_type;
typedef typename any_trainer::sample_type sample_type;
typedef typename any_trainer::scalar_type scalar_type;
typedef typename any_trainer::mem_manager_type mem_manager_type;
typedef one_vs_all_decision_function<one_vs_all_trainer> trained_function_type;
one_vs_all_trainer (
- this object is properly initialized
- this object will not be verbose unless be_verbose() is called
- no binary trainers are associated with *this. I.e. you have to
call set_trainer() before calling train()
void set_trainer (
const any_trainer& trainer
- sets the trainer used for all binary subproblems. Any previous
calls to set_trainer() are overridden by this function, including
calls to the more specific set_trainer(trainer, l) form.
void set_trainer (
const any_trainer& trainer,
const label_type& l
- Sets the trainer object used to create a binary classifier to
distinguish l labeled samples from all other samples.
void be_verbose (
- This object will print status messages to standard out so that a
user can observe the progress of the algorithm.
void be_quiet (
- this object will not print anything to standard out
struct invalid_label : public dlib::error
This is the exception thrown by the train() function below.
label_type l;
trained_function_type train (
const std::vector<sample_type>& all_samples,
const std::vector<label_type>& all_labels
) const;
- is_learning_problem(all_samples, all_labels)
- trains a bunch of binary classifiers in a one vs all fashion to solve the given
multiclass classification problem.
- returns a one_vs_all_decision_function F with the following properties:
- F contains all the learned binary classifiers and can be used to predict
the labels of new samples.
- if (new_x is a sample predicted to have a label of L) then
- F(new_x) == L
- F.get_labels() == select_all_distinct_labels(all_labels)
- F.number_of_classes() == select_all_distinct_labels(all_labels).size()
- invalid_label
This exception is thrown if there are labels in all_labels which don't have
any corresponding trainer object. This will never happen if set_trainer(trainer)
has been called. However, if only the set_trainer(trainer,l) form has been
used then this exception is thrown if not all labels have been given a trainer.
invalid_label::l will contain the label which is missing a trainer object.
Additionally, the exception will contain an informative error message available
via invalid_label::what().
// ----------------------------------------------------------------------------------------
......@@ -63,6 +63,7 @@ set (tests
......@@ -72,6 +72,8 @@ SRC += md5.cpp
SRC += member_function_pointer.cpp
SRC += metaprogramming.cpp
SRC += multithreaded_object.cpp
SRC += one_vs_all_trainer.cpp
SRC += one_vs_one_trainer.cpp
SRC += optimization.cpp
SRC += optimization_test_functions.cpp
SRC += opt_qp_solver.cpp
// Copyright (C) 2010 Davis E. King
// License: Boost Software License. See LICENSE.txt for the full license.
#include "tester.h"
#include <dlib/svm.h>
#include <vector>
#include <sstream>
using namespace test;
using namespace dlib;
using namespace std;
// Logger for this test, constructed with the name "test.one_vs_all_trainer".
dlib::logger dlog("test.one_vs_all_trainer");
// Unit test for one_vs_all_trainer and one_vs_all_decision_function.
// NOTE(review): braces, comment delimiters and several statements of this
// class are not visible in this view; the notes below mark the gaps.
class test_one_vs_all_trainer : public tester
This object represents a unit test. When it is constructed
it adds itself into the testing framework.
test_one_vs_all_trainer (
) :
tester (
"test_one_vs_all_trainer", // the command line argument name for this test
"Run tests on the one_vs_all_trainer stuff.", // the command line argument description
0 // the number of command line arguments for this test
// Generates three groups of 2-D points, each lying on a circle:
//   - radius 0.5 centred on the origin, num+10 points
//   - radius 10  centred on the origin, num+20 points
//   - radius 4   centred on (25,25),    num+30 points
// NOTE(review): the statements that append m to samples and a class label to
// labels are not visible in this view, so the label values used are unknown.
template <typename sample_type, typename label_type>
void generate_data (
std::vector<sample_type>& samples,
std::vector<label_type>& labels
const long num = 50;
sample_type m;
dlib::rand::float_1a rnd;
// make some samples near the origin
double radius = 0.5;
for (long i = 0; i < num+10; ++i)
// pick the upper or lower half of the circle at random
double sign = 1;
if (rnd.get_random_double() < 0.5)
sign = -1;
// x is drawn from [-radius, radius); y is chosen so that (x,y) lies on the circle
m(0) = 2*radius*rnd.get_random_double()-radius;
m(1) = sign*sqrt(radius*radius - m(0)*m(0));
// add this sample to the training set (this test does not run k-means)
// make some samples in a circle around the origin but far away
radius = 10.0;
for (long i = 0; i < num+20; ++i)
double sign = 1;
if (rnd.get_random_double() < 0.5)
sign = -1;
m(0) = 2*radius*rnd.get_random_double()-radius;
m(1) = sign*sqrt(radius*radius - m(0)*m(0));
// add this sample to the training set (this test does not run k-means)
// make some samples in a circle around the point (25,25)
radius = 4.0;
for (long i = 0; i < num+30; ++i)
double sign = 1;
if (rnd.get_random_double() < 0.5)
sign = -1;
m(0) = 2*radius*rnd.get_random_double()-radius;
m(1) = sign*sqrt(radius*radius - m(0)*m(0));
// translate this point away from the origin
m(0) += 25;
m(1) += 25;
// add this sample to the training set (this test does not run k-means)
// Trains a one vs. all classifier on the generated data, checks its
// cross-validation result and predictions, then repeats the prediction checks
// after a serialize/deserialize round trip through the file "df.dat".
template <typename label_type, typename scalar_type>
void run_test (
typedef matrix<scalar_type,2,1> sample_type;
std::vector<sample_type> samples;
std::vector<label_type> labels;
// First, get our labeled set of training data
generate_data(samples, labels);
typedef one_vs_all_trainer<any_trainer<sample_type,scalar_type>,label_type > ova_trainer;
ova_trainer trainer;
typedef polynomial_kernel<sample_type> poly_kernel;
typedef radial_basis_kernel<sample_type> rbf_kernel;
// make the binary trainers and set some parameters
krr_trainer<rbf_kernel> rbf_trainer;
svm_nu_trainer<poly_kernel> poly_trainer;
poly_trainer.set_kernel(poly_kernel(0.1, 1, 2));
// NOTE(review): only the per-label registration for label 1 is visible; any
// registration of rbf_trainer as the default trainer is not shown here.
trainer.set_trainer(poly_trainer, 1);
randomize_samples(samples, labels);
// presumably the last argument (2) is the number of cross-validation folds -- TODO confirm
matrix<scalar_type> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2);
// Expected result: the diagonal holds the three group sizes produced by
// generate_data (50+10, 50+20, 50+30) and every off-diagonal entry is zero.
matrix<scalar_type> ans(3,3);
ans = 60, 0, 0,
0, 70, 0,
0, 0, 80;
DLIB_TEST_MSG(ans == res, "res: \n" << res);
// train on the full data set and spot-check two predictions
one_vs_all_decision_function<ova_trainer> df = trainer.train(samples, labels);
DLIB_TEST(df.number_of_classes() == 3);
DLIB_TEST(df(samples[0]) == labels[0])
DLIB_TEST(df(samples[90]) == labels[90])
// NOTE(review): the start of this declaration is not visible; the two types
// below are listed so that df2/df3 can be serialized.
decision_function<poly_kernel>, // This is the output of the poly_trainer
decision_function<rbf_kernel> // This is the output of the rbf_trainer
> df2, df3;
df2 = df;
// save the function to disk
ofstream fout("df.dat", ios::binary);
serialize(df2, fout);
// load the function back in from disk and store it in df3.
ifstream fin("df.dat", ios::binary);
deserialize(df3, fin);
// the reloaded function must give the same predictions and the same result matrix
DLIB_TEST(df3(samples[0]) == labels[0])
DLIB_TEST(df3(samples[90]) == labels[90])
res = test_multiclass_decision_function(df3, samples, labels);
DLIB_TEST(res == ans);
// Logs the four label/scalar type combinations exercised by this test.
// NOTE(review): the run_test<...>() calls that the log messages announce are
// not visible in this view.
void perform_test (
dlog << LINFO << "run_test<double,double>()";
dlog << LINFO << "run_test<int,double>()";
dlog << LINFO << "run_test<double,float>()";
dlog << LINFO << "run_test<int,float>()";
// Creates the tester object; per the description at the top of the class,
// construction is what adds it to the testing framework.
test_one_vs_all_trainer a;
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment