Commit aed73cd0 authored by Davis King's avatar Davis King

This change is mainly about adding a result_type typedef to the various function objects.

Prior to this change, different function objects declared their return type in different ways,
now this has all been reconciled.  Now they all declare it as a public typedef named result_type.

I also simplified the cross_validate_multiclass_trainer(), cross_validate_trainer(),
test_binary_decision_function(), and test_multiclass_decision_function().  They now always
return double matrices regardless of any other consideration.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404175
parent 7e934b9d
...@@ -15,7 +15,7 @@ namespace dlib ...@@ -15,7 +15,7 @@ namespace dlib
template < template <
typename sample_type_, typename sample_type_,
typename scalar_type_ = double typename result_type_ = double
> >
class any_decision_function class any_decision_function
{ {
...@@ -23,7 +23,7 @@ namespace dlib ...@@ -23,7 +23,7 @@ namespace dlib
public: public:
typedef sample_type_ sample_type; typedef sample_type_ sample_type;
typedef scalar_type_ scalar_type; typedef result_type_ result_type;
typedef default_memory_manager mem_manager_type; typedef default_memory_manager mem_manager_type;
any_decision_function() any_decision_function()
...@@ -69,13 +69,13 @@ namespace dlib ...@@ -69,13 +69,13 @@ namespace dlib
return data.get() == 0; return data.get() == 0;
} }
scalar_type operator() ( result_type operator() (
const sample_type& item const sample_type& item
) const ) const
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT(is_empty() == false, DLIB_ASSERT(is_empty() == false,
"\t scalar_type any_decision_function::operator()" "\t result_type any_decision_function::operator()"
<< "\n\t You can't call operator() on an empty any_decision_function" << "\n\t You can't call operator() on an empty any_decision_function"
<< "\n\t this: " << this << "\n\t this: " << this
); );
...@@ -151,7 +151,7 @@ namespace dlib ...@@ -151,7 +151,7 @@ namespace dlib
scoped_ptr<base>& dest scoped_ptr<base>& dest
) const = 0; ) const = 0;
virtual scalar_type evaluate ( virtual result_type evaluate (
const sample_type& samp const sample_type& samp
) const = 0; ) const = 0;
}; };
...@@ -170,7 +170,7 @@ namespace dlib ...@@ -170,7 +170,7 @@ namespace dlib
dest.reset(new derived<T>(item)); dest.reset(new derived<T>(item));
} }
virtual scalar_type evaluate ( virtual result_type evaluate (
const sample_type& samp const sample_type& samp
) const ) const
{ {
...@@ -185,11 +185,11 @@ namespace dlib ...@@ -185,11 +185,11 @@ namespace dlib
template < template <
typename sample_type, typename sample_type,
typename scalar_type typename result_type
> >
inline void swap ( inline void swap (
any_decision_function<sample_type, scalar_type>& a, any_decision_function<sample_type, result_type>& a,
any_decision_function<sample_type, scalar_type>& b any_decision_function<sample_type, result_type>& b
) { a.swap(b); } ) { a.swap(b); }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -13,7 +13,7 @@ namespace dlib ...@@ -13,7 +13,7 @@ namespace dlib
template < template <
typename sample_type_, typename sample_type_,
typename scalar_type_ = double typename result_type_ = double
> >
class any_decision_function class any_decision_function
{ {
...@@ -26,7 +26,7 @@ namespace dlib ...@@ -26,7 +26,7 @@ namespace dlib
This object is a version of dlib::any that is restricted to containing This object is a version of dlib::any that is restricted to containing
elements which are some kind of function object with an operator() with elements which are some kind of function object with an operator() with
the following signature: the following signature:
scalar_type operator()(const sample_type&) const result_type operator()(const sample_type&) const
It is intended to be used to contain dlib::decision_function objects and It is intended to be used to contain dlib::decision_function objects and
other types which represent learned decision functions. It allows you other types which represent learned decision functions. It allows you
...@@ -37,7 +37,7 @@ namespace dlib ...@@ -37,7 +37,7 @@ namespace dlib
public: public:
typedef sample_type_ sample_type; typedef sample_type_ sample_type;
typedef scalar_type_ scalar_type; typedef result_type_ result_type;
typedef default_memory_manager mem_manager_type; typedef default_memory_manager mem_manager_type;
any_decision_function( any_decision_function(
...@@ -98,7 +98,7 @@ namespace dlib ...@@ -98,7 +98,7 @@ namespace dlib
- returns true - returns true
!*/ !*/
scalar_type operator() ( result_type operator() (
const sample_type& item const sample_type& item
) const; ) const;
/*! /*!
...@@ -174,11 +174,11 @@ namespace dlib ...@@ -174,11 +174,11 @@ namespace dlib
template < template <
typename sample_type, typename sample_type,
typename scalar_type typename result_type
> >
inline void swap ( inline void swap (
any_decision_function<sample_type,scalar_type>& a, any_decision_function<sample_type,result_type>& a,
any_decision_function<sample_type,scalar_type>& b any_decision_function<sample_type,result_type>& b
) { a.swap(b); } ) { a.swap(b); }
/*! /*!
provides a global swap function provides a global swap function
...@@ -189,10 +189,10 @@ namespace dlib ...@@ -189,10 +189,10 @@ namespace dlib
template < template <
typename T, typename T,
typename sample_type, typename sample_type,
typename scalar_type typename result_type
> >
T& any_cast( T& any_cast(
any_decision_function<sample_type,scalar_type>& a any_decision_function<sample_type,result_type>& a
) { return a.cast_to<T>(); } ) { return a.cast_to<T>(); }
/*! /*!
ensures ensures
...@@ -204,10 +204,10 @@ namespace dlib ...@@ -204,10 +204,10 @@ namespace dlib
template < template <
typename T, typename T,
typename sample_type, typename sample_type,
typename scalar_type typename result_type
> >
const T& any_cast( const T& any_cast(
const any_decision_function<sample_type,scalar_type>& a const any_decision_function<sample_type,result_type>& a
) { return a.cast_to<T>(); } ) { return a.cast_to<T>(); }
/*! /*!
ensures ensures
......
...@@ -514,6 +514,7 @@ namespace dlib ...@@ -514,6 +514,7 @@ namespace dlib
public: public:
typedef typename matrix_type::mem_manager_type mem_manager_type; typedef typename matrix_type::mem_manager_type mem_manager_type;
typedef typename matrix_type::type scalar_type; typedef typename matrix_type::type scalar_type;
typedef matrix_type result_type;
template <typename vector_type> template <typename vector_type>
void train ( void train (
...@@ -555,7 +556,7 @@ namespace dlib ...@@ -555,7 +556,7 @@ namespace dlib
return sd; return sd;
} }
const matrix_type& operator() ( const result_type& operator() (
const matrix_type& x const matrix_type& x
) const ) const
{ {
...@@ -666,6 +667,7 @@ namespace dlib ...@@ -666,6 +667,7 @@ namespace dlib
public: public:
typedef typename matrix_type::mem_manager_type mem_manager_type; typedef typename matrix_type::mem_manager_type mem_manager_type;
typedef typename matrix_type::type scalar_type; typedef typename matrix_type::type scalar_type;
typedef matrix<scalar_type,0,1,mem_manager_type> result_type;
template <typename vector_type> template <typename vector_type>
void train ( void train (
...@@ -722,7 +724,7 @@ namespace dlib ...@@ -722,7 +724,7 @@ namespace dlib
return pca; return pca;
} }
const matrix<scalar_type,0,1,mem_manager_type>& operator() ( const result_type& operator() (
const matrix_type& x const matrix_type& x
) const ) const
{ {
...@@ -838,7 +840,7 @@ namespace dlib ...@@ -838,7 +840,7 @@ namespace dlib
// This is just a temporary variable that doesn't contribute to the // This is just a temporary variable that doesn't contribute to the
// state of this object. // state of this object.
mutable matrix<scalar_type,0,1,mem_manager_type> temp_out; mutable result_type temp_out;
}; };
template < template <
......
...@@ -437,6 +437,7 @@ namespace dlib ...@@ -437,6 +437,7 @@ namespace dlib
public: public:
typedef typename matrix_type::mem_manager_type mem_manager_type; typedef typename matrix_type::mem_manager_type mem_manager_type;
typedef typename matrix_type::type scalar_type; typedef typename matrix_type::type scalar_type;
typedef matrix_type result_type;
template <typename vector_type> template <typename vector_type>
void train ( void train (
...@@ -495,7 +496,7 @@ namespace dlib ...@@ -495,7 +496,7 @@ namespace dlib
input feature shown to train() input feature shown to train()
!*/ !*/
const matrix_type& operator() ( const result_type& operator() (
const matrix_type& x const matrix_type& x
) const; ) const;
/*! /*!
...@@ -592,6 +593,7 @@ namespace dlib ...@@ -592,6 +593,7 @@ namespace dlib
public: public:
typedef typename matrix_type::mem_manager_type mem_manager_type; typedef typename matrix_type::mem_manager_type mem_manager_type;
typedef typename matrix_type::type scalar_type; typedef typename matrix_type::type scalar_type;
typedef matrix<scalar_type,0,1,mem_manager_type> result_type;
template <typename vector_type> template <typename vector_type>
void train ( void train (
...@@ -670,7 +672,7 @@ namespace dlib ...@@ -670,7 +672,7 @@ namespace dlib
matrix matrix
!*/ !*/
const matrix<scalar_type,0,1,mem_manager_type>& operator() ( const result_type& operator() (
const matrix_type& x const matrix_type& x
) const; ) const;
/*! /*!
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <vector> #include <vector>
#include "../matrix.h" #include "../matrix.h"
#include "one_vs_one_trainer.h" #include "one_vs_one_trainer.h"
#include "cross_validate_multiclass_trainer_abstract.h"
namespace dlib namespace dlib
{ {
...@@ -17,14 +18,12 @@ namespace dlib ...@@ -17,14 +18,12 @@ namespace dlib
typename sample_type, typename sample_type,
typename label_type typename label_type
> >
const matrix<typename dec_funct_type::scalar_type, 0, 0, typename dec_funct_type::mem_manager_type> const matrix<double> test_multiclass_decision_function (
test_multiclass_decision_function (
const dec_funct_type& dec_funct, const dec_funct_type& dec_funct,
const std::vector<sample_type>& x_test, const std::vector<sample_type>& x_test,
const std::vector<label_type>& y_test const std::vector<label_type>& y_test
) )
{ {
typedef typename dec_funct_type::scalar_type scalar_type;
typedef typename dec_funct_type::mem_manager_type mem_manager_type; typedef typename dec_funct_type::mem_manager_type mem_manager_type;
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -42,7 +41,7 @@ namespace dlib ...@@ -42,7 +41,7 @@ namespace dlib
for (unsigned long i = 0; i < all_labels.size(); ++i) for (unsigned long i = 0; i < all_labels.size(); ++i)
label_to_int[all_labels[i]] = i; label_to_int[all_labels[i]] = i;
matrix<typename dec_funct_type::scalar_type, 0, 0, typename dec_funct_type::mem_manager_type> res; matrix<double, 0, 0, typename dec_funct_type::mem_manager_type> res;
res.set_size(all_labels.size(), all_labels.size()); res.set_size(all_labels.size(), all_labels.size());
res = 0; res = 0;
...@@ -73,15 +72,13 @@ namespace dlib ...@@ -73,15 +72,13 @@ namespace dlib
typename sample_type, typename sample_type,
typename label_type typename label_type
> >
const matrix<typename trainer_type::scalar_type, 0, 0, typename trainer_type::mem_manager_type> const matrix<double> cross_validate_multiclass_trainer (
cross_validate_multiclass_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const std::vector<sample_type>& x, const std::vector<sample_type>& x,
const std::vector<label_type>& y, const std::vector<label_type>& y,
const long folds const long folds
) )
{ {
typedef typename trainer_type::scalar_type scalar_type;
typedef typename trainer_type::mem_manager_type mem_manager_type; typedef typename trainer_type::mem_manager_type mem_manager_type;
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -116,7 +113,7 @@ namespace dlib ...@@ -116,7 +113,7 @@ namespace dlib
std::vector<sample_type> x_test, x_train; std::vector<sample_type> x_test, x_train;
std::vector<label_type> y_test, y_train; std::vector<label_type> y_test, y_train;
matrix<scalar_type, 0, 0, mem_manager_type> res; matrix<double, 0, 0, mem_manager_type> res;
std::map<label_type,long> next_test_idx; std::map<label_type,long> next_test_idx;
for (unsigned long i = 0; i < all_labels.size(); ++i) for (unsigned long i = 0; i < all_labels.size(); ++i)
......
...@@ -16,8 +16,7 @@ namespace dlib ...@@ -16,8 +16,7 @@ namespace dlib
typename sample_type, typename sample_type,
typename label_type typename label_type
> >
const matrix<typename dec_funct_type::scalar_type, 0, 0, typename dec_funct_type::mem_manager_type> const matrix<double> test_multiclass_decision_function (
test_multiclass_decision_function (
const dec_funct_type& dec_funct, const dec_funct_type& dec_funct,
const std::vector<sample_type>& x_test, const std::vector<sample_type>& x_test,
const std::vector<label_type>& y_test const std::vector<label_type>& y_test
...@@ -46,8 +45,7 @@ namespace dlib ...@@ -46,8 +45,7 @@ namespace dlib
typename sample_type, typename sample_type,
typename label_type typename label_type
> >
const matrix<typename trainer_type::scalar_type, 0, 0, typename trainer_type::mem_manager_type> const matrix<double> cross_validate_multiclass_trainer (
cross_validate_multiclass_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const std::vector<sample_type>& x, const std::vector<sample_type>& x,
const std::vector<label_type>& y, const std::vector<label_type>& y,
......
...@@ -24,7 +24,6 @@ namespace dlib ...@@ -24,7 +24,6 @@ namespace dlib
const std::vector<label_type>& y_test const std::vector<label_type>& y_test
) )
{ {
typedef typename reg_funct_type::scalar_type scalar_type;
typedef typename reg_funct_type::mem_manager_type mem_manager_type; typedef typename reg_funct_type::mem_manager_type mem_manager_type;
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -62,7 +61,6 @@ namespace dlib ...@@ -62,7 +61,6 @@ namespace dlib
const long folds const long folds
) )
{ {
typedef typename trainer_type::scalar_type scalar_type;
typedef typename trainer_type::mem_manager_type mem_manager_type; typedef typename trainer_type::mem_manager_type mem_manager_type;
// make sure requires clause is not broken // make sure requires clause is not broken
......
...@@ -26,6 +26,7 @@ namespace dlib ...@@ -26,6 +26,7 @@ namespace dlib
{ {
typedef K kernel_type; typedef K kernel_type;
typedef typename K::scalar_type scalar_type; typedef typename K::scalar_type scalar_type;
typedef typename K::scalar_type result_type;
typedef typename K::sample_type sample_type; typedef typename K::sample_type sample_type;
typedef typename K::mem_manager_type mem_manager_type; typedef typename K::mem_manager_type mem_manager_type;
...@@ -75,11 +76,11 @@ namespace dlib ...@@ -75,11 +76,11 @@ namespace dlib
return *this; return *this;
} }
scalar_type operator() ( result_type operator() (
const sample_type& x const sample_type& x
) const ) const
{ {
scalar_type temp = 0; result_type temp = 0;
for (long i = 0; i < alpha.nr(); ++i) for (long i = 0; i < alpha.nr(); ++i)
temp += alpha(i) * kernel_function(x,basis_vectors(i)); temp += alpha(i) * kernel_function(x,basis_vectors(i));
...@@ -137,6 +138,7 @@ namespace dlib ...@@ -137,6 +138,7 @@ namespace dlib
struct probabilistic_function struct probabilistic_function
{ {
typedef typename function_type::scalar_type scalar_type; typedef typename function_type::scalar_type scalar_type;
typedef typename function_type::result_type result_type;
typedef typename function_type::sample_type sample_type; typedef typename function_type::sample_type sample_type;
typedef typename function_type::mem_manager_type mem_manager_type; typedef typename function_type::mem_manager_type mem_manager_type;
...@@ -178,11 +180,11 @@ namespace dlib ...@@ -178,11 +180,11 @@ namespace dlib
return *this; return *this;
} }
scalar_type operator() ( result_type operator() (
const sample_type& x const sample_type& x
) const ) const
{ {
scalar_type f = decision_funct(x); result_type f = decision_funct(x);
return 1/(1 + std::exp(alpha*f + beta)); return 1/(1 + std::exp(alpha*f + beta));
} }
}; };
...@@ -236,6 +238,7 @@ namespace dlib ...@@ -236,6 +238,7 @@ namespace dlib
{ {
typedef K kernel_type; typedef K kernel_type;
typedef typename K::scalar_type scalar_type; typedef typename K::scalar_type scalar_type;
typedef typename K::scalar_type result_type;
typedef typename K::sample_type sample_type; typedef typename K::sample_type sample_type;
typedef typename K::mem_manager_type mem_manager_type; typedef typename K::mem_manager_type mem_manager_type;
...@@ -285,11 +288,11 @@ namespace dlib ...@@ -285,11 +288,11 @@ namespace dlib
return *this; return *this;
} }
scalar_type operator() ( result_type operator() (
const sample_type& x const sample_type& x
) const ) const
{ {
scalar_type f = decision_funct(x); result_type f = decision_funct(x);
return 1/(1 + std::exp(alpha*f + beta)); return 1/(1 + std::exp(alpha*f + beta));
} }
}; };
...@@ -344,6 +347,7 @@ namespace dlib ...@@ -344,6 +347,7 @@ namespace dlib
public: public:
typedef K kernel_type; typedef K kernel_type;
typedef typename K::scalar_type scalar_type; typedef typename K::scalar_type scalar_type;
typedef typename K::scalar_type result_type;
typedef typename K::sample_type sample_type; typedef typename K::sample_type sample_type;
typedef typename K::mem_manager_type mem_manager_type; typedef typename K::mem_manager_type mem_manager_type;
...@@ -462,11 +466,11 @@ namespace dlib ...@@ -462,11 +466,11 @@ namespace dlib
return *this; return *this;
} }
scalar_type operator() ( result_type operator() (
const sample_type& x const sample_type& x
) const ) const
{ {
scalar_type temp = 0; result_type temp = 0;
for (long i = 0; i < alpha.nr(); ++i) for (long i = 0; i < alpha.nr(); ++i)
temp += alpha(i) * kernel_function(x,basis_vectors(i)); temp += alpha(i) * kernel_function(x,basis_vectors(i));
...@@ -477,11 +481,11 @@ namespace dlib ...@@ -477,11 +481,11 @@ namespace dlib
return 0; return 0;
} }
scalar_type operator() ( result_type operator() (
const distance_function& x const distance_function& x
) const ) const
{ {
scalar_type temp = 0; result_type temp = 0;
for (long i = 0; i < alpha.nr(); ++i) for (long i = 0; i < alpha.nr(); ++i)
for (long j = 0; j < x.alpha.nr(); ++j) for (long j = 0; j < x.alpha.nr(); ++j)
temp += alpha(i)*x.alpha(j) * kernel_function(basis_vectors(i), x.basis_vectors(j)); temp += alpha(i)*x.alpha(j) * kernel_function(basis_vectors(i), x.basis_vectors(j));
...@@ -624,7 +628,7 @@ namespace dlib ...@@ -624,7 +628,7 @@ namespace dlib
> >
struct normalized_function struct normalized_function
{ {
typedef typename function_type::scalar_type scalar_type; typedef typename function_type::result_type result_type;
typedef typename function_type::sample_type sample_type; typedef typename function_type::sample_type sample_type;
typedef typename function_type::mem_manager_type mem_manager_type; typedef typename function_type::mem_manager_type mem_manager_type;
...@@ -646,7 +650,7 @@ namespace dlib ...@@ -646,7 +650,7 @@ namespace dlib
const function_type& funct const function_type& funct
) : normalizer(normalizer_), function(funct) {} ) : normalizer(normalizer_), function(funct) {}
scalar_type operator() ( result_type operator() (
const sample_type& x const sample_type& x
) const { return function(normalizer(x)); } ) const { return function(normalizer(x)); }
}; };
...@@ -706,6 +710,7 @@ namespace dlib ...@@ -706,6 +710,7 @@ namespace dlib
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type; typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type; typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type;
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type; typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
typedef scalar_vector_type result_type;
scalar_matrix_type weights; scalar_matrix_type weights;
K kernel_function; K kernel_function;
...@@ -727,7 +732,7 @@ namespace dlib ...@@ -727,7 +732,7 @@ namespace dlib
long out_vector_size ( long out_vector_size (
) const { return weights.nr(); } ) const { return weights.nr(); }
const scalar_vector_type& operator() ( const result_type& operator() (
const sample_type& x const sample_type& x
) const ) const
{ {
...@@ -741,7 +746,7 @@ namespace dlib ...@@ -741,7 +746,7 @@ namespace dlib
} }
private: private:
mutable scalar_vector_type temp1, temp2; mutable result_type temp1, temp2;
}; };
template < template <
......
...@@ -34,6 +34,7 @@ namespace dlib ...@@ -34,6 +34,7 @@ namespace dlib
typedef K kernel_type; typedef K kernel_type;
typedef typename K::scalar_type scalar_type; typedef typename K::scalar_type scalar_type;
typedef typename K::scalar_type result_type;
typedef typename K::sample_type sample_type; typedef typename K::sample_type sample_type;
typedef typename K::mem_manager_type mem_manager_type; typedef typename K::mem_manager_type mem_manager_type;
...@@ -83,7 +84,7 @@ namespace dlib ...@@ -83,7 +84,7 @@ namespace dlib
- returns *this - returns *this
!*/ !*/
scalar_type operator() ( result_type operator() (
const sample_type& x const sample_type& x
) const ) const
/*! /*!
...@@ -92,7 +93,7 @@ namespace dlib ...@@ -92,7 +93,7 @@ namespace dlib
function contained in this object. function contained in this object.
!*/ !*/
{ {
scalar_type temp = 0; result_type temp = 0;
for (long i = 0; i < alpha.nr(); ++i) for (long i = 0; i < alpha.nr(); ++i)
temp += alpha(i) * kernel_function(x,basis_vectors(i)); temp += alpha(i) * kernel_function(x,basis_vectors(i));
...@@ -142,6 +143,7 @@ namespace dlib ...@@ -142,6 +143,7 @@ namespace dlib
!*/ !*/
typedef typename function_type::scalar_type scalar_type; typedef typename function_type::scalar_type scalar_type;
typedef typename function_type::result_type result_type;
typedef typename function_type::sample_type sample_type; typedef typename function_type::sample_type sample_type;
typedef typename function_type::mem_manager_type mem_manager_type; typedef typename function_type::mem_manager_type mem_manager_type;
...@@ -186,7 +188,7 @@ namespace dlib ...@@ -186,7 +188,7 @@ namespace dlib
- returns *this - returns *this
!*/ !*/
scalar_type operator() ( result_type operator() (
const sample_type& x const sample_type& x
) const ) const
/*! /*!
...@@ -198,7 +200,7 @@ namespace dlib ...@@ -198,7 +200,7 @@ namespace dlib
!*/ !*/
{ {
// Evaluate the normal decision function // Evaluate the normal decision function
scalar_type f = decision_funct(x); result_type f = decision_funct(x);
// Now basically normalize the output so that it is a properly // Now basically normalize the output so that it is a properly
// conditioned probability of x being in the +1 class given // conditioned probability of x being in the +1 class given
// the output of the decision function. // the output of the decision function.
...@@ -253,6 +255,7 @@ namespace dlib ...@@ -253,6 +255,7 @@ namespace dlib
typedef K kernel_type; typedef K kernel_type;
typedef typename K::scalar_type scalar_type; typedef typename K::scalar_type scalar_type;
typedef typename K::scalar_type result_type;
typedef typename K::sample_type sample_type; typedef typename K::sample_type sample_type;
typedef typename K::mem_manager_type mem_manager_type; typedef typename K::mem_manager_type mem_manager_type;
...@@ -305,7 +308,7 @@ namespace dlib ...@@ -305,7 +308,7 @@ namespace dlib
- returns *this - returns *this
!*/ !*/
scalar_type operator() ( result_type operator() (
const sample_type& x const sample_type& x
) const ) const
/*! /*!
...@@ -317,7 +320,7 @@ namespace dlib ...@@ -317,7 +320,7 @@ namespace dlib
!*/ !*/
{ {
// Evaluate the normal decision function // Evaluate the normal decision function
scalar_type f = decision_funct(x); result_type f = decision_funct(x);
// Now basically normalize the output so that it is a properly // Now basically normalize the output so that it is a properly
// conditioned probability of x being in the +1 class given // conditioned probability of x being in the +1 class given
// the output of the decision function. // the output of the decision function.
...@@ -377,6 +380,7 @@ namespace dlib ...@@ -377,6 +380,7 @@ namespace dlib
public: public:
typedef K kernel_type; typedef K kernel_type;
typedef typename K::scalar_type scalar_type; typedef typename K::scalar_type scalar_type;
typedef typename K::scalar_type result_type;
typedef typename K::sample_type sample_type; typedef typename K::sample_type sample_type;
typedef typename K::mem_manager_type mem_manager_type; typedef typename K::mem_manager_type mem_manager_type;
...@@ -523,7 +527,7 @@ namespace dlib ...@@ -523,7 +527,7 @@ namespace dlib
- returns *this - returns *this
!*/ !*/
scalar_type operator() ( result_type operator() (
const sample_type& x const sample_type& x
) const; ) const;
/*! /*!
...@@ -535,7 +539,7 @@ namespace dlib ...@@ -535,7 +539,7 @@ namespace dlib
space. space.
!*/ !*/
scalar_type operator() ( result_type operator() (
const distance_function& x const distance_function& x
) const; ) const;
/*! /*!
...@@ -661,7 +665,7 @@ namespace dlib ...@@ -661,7 +665,7 @@ namespace dlib
off to the contained function object. off to the contained function object.
!*/ !*/
typedef typename function_type::scalar_type scalar_type; typedef typename function_type::result_type result_type;
typedef typename function_type::sample_type sample_type; typedef typename function_type::sample_type sample_type;
typedef typename function_type::mem_manager_type mem_manager_type; typedef typename function_type::mem_manager_type mem_manager_type;
...@@ -701,7 +705,7 @@ namespace dlib ...@@ -701,7 +705,7 @@ namespace dlib
- returns *this - returns *this
!*/ !*/
scalar_type operator() ( result_type operator() (
const sample_type& x const sample_type& x
) const ) const
/*! /*!
...@@ -760,6 +764,7 @@ namespace dlib ...@@ -760,6 +764,7 @@ namespace dlib
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type; typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type; typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type;
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type; typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
typedef scalar_vector_type result_type;
scalar_matrix_type weights; scalar_matrix_type weights;
K kernel_function; K kernel_function;
...@@ -809,7 +814,7 @@ namespace dlib ...@@ -809,7 +814,7 @@ namespace dlib
(i.e. returns the dimensionality of the vectors output by this projection_function.) (i.e. returns the dimensionality of the vectors output by this projection_function.)
!*/ !*/
const scalar_vector_type& operator() ( const result_type& operator() (
const sample_type& x const sample_type& x
) const ) const
/*! /*!
...@@ -832,7 +837,7 @@ namespace dlib ...@@ -832,7 +837,7 @@ namespace dlib
} }
private: private:
mutable scalar_vector_type temp1, temp2; mutable result_type temp1, temp2;
}; };
template < template <
......
...@@ -28,12 +28,12 @@ namespace dlib ...@@ -28,12 +28,12 @@ namespace dlib
{ {
public: public:
typedef typename one_vs_all_trainer::label_type label_type; typedef typename one_vs_all_trainer::label_type result_type;
typedef typename one_vs_all_trainer::sample_type sample_type; typedef typename one_vs_all_trainer::sample_type sample_type;
typedef typename one_vs_all_trainer::scalar_type scalar_type; typedef typename one_vs_all_trainer::scalar_type scalar_type;
typedef typename one_vs_all_trainer::mem_manager_type mem_manager_type; typedef typename one_vs_all_trainer::mem_manager_type mem_manager_type;
typedef std::map<label_type, any_decision_function<sample_type, scalar_type> > binary_function_table; typedef std::map<result_type, any_decision_function<sample_type, scalar_type> > binary_function_table;
one_vs_all_decision_function() :num_classes(0) {} one_vs_all_decision_function() :num_classes(0) {}
...@@ -50,10 +50,10 @@ namespace dlib ...@@ -50,10 +50,10 @@ namespace dlib
return dfs; return dfs;
} }
const std::vector<label_type> get_labels ( const std::vector<result_type> get_labels (
) const ) const
{ {
std::vector<label_type> temp; std::vector<result_type> temp;
temp.reserve(dfs.size()); temp.reserve(dfs.size());
for (typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i) for (typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i)
{ {
...@@ -79,7 +79,7 @@ namespace dlib ...@@ -79,7 +79,7 @@ namespace dlib
return num_classes; return num_classes;
} }
label_type operator() ( result_type operator() (
const sample_type& sample const sample_type& sample
) const ) const
{ {
...@@ -89,7 +89,7 @@ namespace dlib ...@@ -89,7 +89,7 @@ namespace dlib
<< "\n\t this: " << this << "\n\t this: " << this
); );
label_type best_label = label_type(); result_type best_label = result_type();
scalar_type best_score = -std::numeric_limits<scalar_type>::infinity(); scalar_type best_score = -std::numeric_limits<scalar_type>::infinity();
// run all the classifiers over the sample and find the best one // run all the classifiers over the sample and find the best one
...@@ -132,10 +132,10 @@ namespace dlib ...@@ -132,10 +132,10 @@ namespace dlib
try try
{ {
type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp; type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp;
typedef typename T::label_type label_type; typedef typename T::label_type result_type;
typedef typename T::sample_type sample_type; typedef typename T::sample_type sample_type;
typedef typename T::scalar_type scalar_type; typedef typename T::scalar_type scalar_type;
typedef std::map<label_type, any_decision_function<sample_type, scalar_type> > binary_function_table; typedef std::map<result_type, any_decision_function<sample_type, scalar_type> > binary_function_table;
const unsigned long version = 1; const unsigned long version = 1;
serialize(version, out); serialize(version, out);
...@@ -206,7 +206,7 @@ namespace dlib ...@@ -206,7 +206,7 @@ namespace dlib
try try
{ {
type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp; type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp;
typedef typename T::label_type label_type; typedef typename T::label_type result_type;
typedef typename T::sample_type sample_type; typedef typename T::sample_type sample_type;
typedef typename T::scalar_type scalar_type; typedef typename T::scalar_type scalar_type;
typedef impl_ova::copy_to_df_helper<sample_type, scalar_type> copy_to; typedef impl_ova::copy_to_df_helper<sample_type, scalar_type> copy_to;
...@@ -220,10 +220,10 @@ namespace dlib ...@@ -220,10 +220,10 @@ namespace dlib
unsigned long size; unsigned long size;
deserialize(size, in); deserialize(size, in);
typedef std::map<label_type, any_decision_function<sample_type, scalar_type> > binary_function_table; typedef std::map<result_type, any_decision_function<sample_type, scalar_type> > binary_function_table;
binary_function_table dfs; binary_function_table dfs;
label_type l; result_type l;
for (unsigned long i = 0; i < size; ++i) for (unsigned long i = 0; i < size; ++i)
{ {
deserialize(l, in); deserialize(l, in);
......
...@@ -52,12 +52,12 @@ namespace dlib ...@@ -52,12 +52,12 @@ namespace dlib
!*/ !*/
public: public:
typedef typename one_vs_all_trainer::label_type label_type; typedef typename one_vs_all_trainer::label_type result_type;
typedef typename one_vs_all_trainer::sample_type sample_type; typedef typename one_vs_all_trainer::sample_type sample_type;
typedef typename one_vs_all_trainer::scalar_type scalar_type; typedef typename one_vs_all_trainer::scalar_type scalar_type;
typedef typename one_vs_all_trainer::mem_manager_type mem_manager_type; typedef typename one_vs_all_trainer::mem_manager_type mem_manager_type;
typedef std::map<label_type, any_decision_function<sample_type, scalar_type> > binary_function_table; typedef std::map<result_type, any_decision_function<sample_type, scalar_type> > binary_function_table;
one_vs_all_decision_function( one_vs_all_decision_function(
); );
...@@ -107,7 +107,7 @@ namespace dlib ...@@ -107,7 +107,7 @@ namespace dlib
with that decision function. with that decision function.
!*/ !*/
const std::vector<label_type> get_labels ( const std::vector<result_type> get_labels (
) const; ) const;
/*! /*!
ensures ensures
...@@ -124,7 +124,7 @@ namespace dlib ...@@ -124,7 +124,7 @@ namespace dlib
this object) this object)
!*/ !*/
label_type operator() ( result_type operator() (
const sample_type& sample const sample_type& sample
) const ) const
/*! /*!
......
...@@ -31,12 +31,12 @@ namespace dlib ...@@ -31,12 +31,12 @@ namespace dlib
{ {
public: public:
typedef typename one_vs_one_trainer::label_type label_type; typedef typename one_vs_one_trainer::label_type result_type;
typedef typename one_vs_one_trainer::sample_type sample_type; typedef typename one_vs_one_trainer::sample_type sample_type;
typedef typename one_vs_one_trainer::scalar_type scalar_type; typedef typename one_vs_one_trainer::scalar_type scalar_type;
typedef typename one_vs_one_trainer::mem_manager_type mem_manager_type; typedef typename one_vs_one_trainer::mem_manager_type mem_manager_type;
typedef std::map<unordered_pair<label_type>, any_decision_function<sample_type, scalar_type> > binary_function_table; typedef std::map<unordered_pair<result_type>, any_decision_function<sample_type, scalar_type> > binary_function_table;
one_vs_one_decision_function() :num_classes(0) {} one_vs_one_decision_function() :num_classes(0) {}
...@@ -46,7 +46,7 @@ namespace dlib ...@@ -46,7 +46,7 @@ namespace dlib
{ {
#ifdef ENABLE_ASSERTS #ifdef ENABLE_ASSERTS
{ {
const std::vector<unordered_pair<label_type> > missing_pairs = find_missing_pairs(dfs_); const std::vector<unordered_pair<result_type> > missing_pairs = find_missing_pairs(dfs_);
if (missing_pairs.size() != 0) if (missing_pairs.size() != 0)
{ {
std::ostringstream sout; std::ostringstream sout;
...@@ -65,7 +65,7 @@ namespace dlib ...@@ -65,7 +65,7 @@ namespace dlib
#endif #endif
// figure out how many labels are covered by this set of binary decision functions // figure out how many labels are covered by this set of binary decision functions
std::set<label_type> labels; std::set<result_type> labels;
for (typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i) for (typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i)
{ {
labels.insert(i->first.first); labels.insert(i->first.first);
...@@ -80,16 +80,16 @@ namespace dlib ...@@ -80,16 +80,16 @@ namespace dlib
return dfs; return dfs;
} }
const std::vector<label_type> get_labels ( const std::vector<result_type> get_labels (
) const ) const
{ {
std::set<label_type> labels; std::set<result_type> labels;
for (typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i) for (typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i)
{ {
labels.insert(i->first.first); labels.insert(i->first.first);
labels.insert(i->first.second); labels.insert(i->first.second);
} }
return std::vector<label_type>(labels.begin(), labels.end()); return std::vector<result_type>(labels.begin(), labels.end());
} }
...@@ -109,7 +109,7 @@ namespace dlib ...@@ -109,7 +109,7 @@ namespace dlib
return num_classes; return num_classes;
} }
label_type operator() ( result_type operator() (
const sample_type& sample const sample_type& sample
) const ) const
{ {
...@@ -119,7 +119,7 @@ namespace dlib ...@@ -119,7 +119,7 @@ namespace dlib
<< "\n\t this: " << this << "\n\t this: " << this
); );
std::map<label_type,int> votes; std::map<result_type,int> votes;
// run all the classifiers over the sample // run all the classifiers over the sample
for(typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i) for(typename binary_function_table::const_iterator i = dfs.begin(); i != dfs.end(); ++i)
...@@ -133,9 +133,9 @@ namespace dlib ...@@ -133,9 +133,9 @@ namespace dlib
} }
// now figure out who had the most votes // now figure out who had the most votes
label_type best_label = label_type(); result_type best_label = result_type();
int best_votes = 0; int best_votes = 0;
for (typename std::map<label_type,int>::iterator i = votes.begin(); i != votes.end(); ++i) for (typename std::map<result_type,int>::iterator i = votes.begin(); i != votes.end(); ++i)
{ {
if (i->second > best_votes) if (i->second > best_votes)
{ {
...@@ -172,10 +172,10 @@ namespace dlib ...@@ -172,10 +172,10 @@ namespace dlib
try try
{ {
type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp; type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp;
typedef typename T::label_type label_type; typedef typename T::label_type result_type;
typedef typename T::sample_type sample_type; typedef typename T::sample_type sample_type;
typedef typename T::scalar_type scalar_type; typedef typename T::scalar_type scalar_type;
typedef std::map<unordered_pair<label_type>, any_decision_function<sample_type, scalar_type> > binary_function_table; typedef std::map<unordered_pair<result_type>, any_decision_function<sample_type, scalar_type> > binary_function_table;
const unsigned long version = 1; const unsigned long version = 1;
serialize(version, out); serialize(version, out);
...@@ -246,7 +246,7 @@ namespace dlib ...@@ -246,7 +246,7 @@ namespace dlib
try try
{ {
type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp; type_safe_union<DF1,DF2,DF3,DF4,DF5,DF6,DF7,DF8,DF9,DF10> temp;
typedef typename T::label_type label_type; typedef typename T::label_type result_type;
typedef typename T::sample_type sample_type; typedef typename T::sample_type sample_type;
typedef typename T::scalar_type scalar_type; typedef typename T::scalar_type scalar_type;
typedef impl::copy_to_df_helper<sample_type, scalar_type> copy_to; typedef impl::copy_to_df_helper<sample_type, scalar_type> copy_to;
...@@ -260,10 +260,10 @@ namespace dlib ...@@ -260,10 +260,10 @@ namespace dlib
unsigned long size; unsigned long size;
deserialize(size, in); deserialize(size, in);
typedef std::map<unordered_pair<label_type>, any_decision_function<sample_type, scalar_type> > binary_function_table; typedef std::map<unordered_pair<result_type>, any_decision_function<sample_type, scalar_type> > binary_function_table;
binary_function_table dfs; binary_function_table dfs;
unordered_pair<label_type> p; unordered_pair<result_type> p;
for (unsigned long i = 0; i < size; ++i) for (unsigned long i = 0; i < size; ++i)
{ {
deserialize(p, in); deserialize(p, in);
......
...@@ -53,12 +53,12 @@ namespace dlib ...@@ -53,12 +53,12 @@ namespace dlib
!*/ !*/
public: public:
typedef typename one_vs_one_trainer::label_type label_type; typedef typename one_vs_one_trainer::label_type result_type;
typedef typename one_vs_one_trainer::sample_type sample_type; typedef typename one_vs_one_trainer::sample_type sample_type;
typedef typename one_vs_one_trainer::scalar_type scalar_type; typedef typename one_vs_one_trainer::scalar_type scalar_type;
typedef typename one_vs_one_trainer::mem_manager_type mem_manager_type; typedef typename one_vs_one_trainer::mem_manager_type mem_manager_type;
typedef std::map<unordered_pair<label_type>, any_decision_function<sample_type, scalar_type> > binary_function_table; typedef std::map<unordered_pair<result_type>, any_decision_function<sample_type, scalar_type> > binary_function_table;
one_vs_one_decision_function( one_vs_one_decision_function(
); );
...@@ -118,7 +118,7 @@ namespace dlib ...@@ -118,7 +118,7 @@ namespace dlib
receive a label of i->first.second receive a label of i->first.second
!*/ !*/
const std::vector<label_type> get_labels ( const std::vector<result_type> get_labels (
) const; ) const;
/*! /*!
ensures ensures
...@@ -135,7 +135,7 @@ namespace dlib ...@@ -135,7 +135,7 @@ namespace dlib
this object) this object)
!*/ !*/
label_type operator() ( result_type operator() (
const sample_type& sample const sample_type& sample
) const ) const
/*! /*!
......
...@@ -103,18 +103,15 @@ namespace dlib ...@@ -103,18 +103,15 @@ namespace dlib
typename in_sample_vector_type, typename in_sample_vector_type,
typename in_scalar_vector_type typename in_scalar_vector_type
> >
const matrix<typename dec_funct_type::scalar_type, 1, 2, typename dec_funct_type::mem_manager_type> const matrix<double,1,2> test_binary_decision_function_impl (
test_binary_decision_function_impl (
const dec_funct_type& dec_funct, const dec_funct_type& dec_funct,
const in_sample_vector_type& x_test, const in_sample_vector_type& x_test,
const in_scalar_vector_type& y_test const in_scalar_vector_type& y_test
) )
{ {
typedef typename dec_funct_type::scalar_type scalar_type;
typedef typename dec_funct_type::sample_type sample_type; typedef typename dec_funct_type::sample_type sample_type;
typedef typename dec_funct_type::mem_manager_type mem_manager_type; typedef typename dec_funct_type::mem_manager_type mem_manager_type;
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type; typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_ASSERT( is_binary_classification_problem(x_test,y_test) == true, DLIB_ASSERT( is_binary_classification_problem(x_test,y_test) == true,
...@@ -156,9 +153,9 @@ namespace dlib ...@@ -156,9 +153,9 @@ namespace dlib
} }
matrix<scalar_type, 1, 2, mem_manager_type> res; matrix<double, 1, 2, mem_manager_type> res;
res(0) = (scalar_type)num_pos_correct/(scalar_type)(num_pos); res(0) = (double)num_pos_correct/(double)(num_pos);
res(1) = (scalar_type)num_neg_correct/(scalar_type)(num_neg); res(1) = (double)num_neg_correct/(double)(num_neg);
return res; return res;
} }
...@@ -167,8 +164,7 @@ namespace dlib ...@@ -167,8 +164,7 @@ namespace dlib
typename in_sample_vector_type, typename in_sample_vector_type,
typename in_scalar_vector_type typename in_scalar_vector_type
> >
const matrix<typename dec_funct_type::scalar_type, 1, 2, typename dec_funct_type::mem_manager_type> const matrix<double,1,2> test_binary_decision_function (
test_binary_decision_function (
const dec_funct_type& dec_funct, const dec_funct_type& dec_funct,
const in_sample_vector_type& x_test, const in_sample_vector_type& x_test,
const in_scalar_vector_type& y_test const in_scalar_vector_type& y_test
...@@ -186,7 +182,7 @@ namespace dlib ...@@ -186,7 +182,7 @@ namespace dlib
typename in_sample_vector_type, typename in_sample_vector_type,
typename in_scalar_vector_type typename in_scalar_vector_type
> >
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type> const matrix<double, 1, 2, typename trainer_type::mem_manager_type>
cross_validate_trainer_impl ( cross_validate_trainer_impl (
const trainer_type& trainer, const trainer_type& trainer,
const in_sample_vector_type& x, const in_sample_vector_type& x,
...@@ -194,7 +190,7 @@ namespace dlib ...@@ -194,7 +190,7 @@ namespace dlib
const long folds const long folds
) )
{ {
typedef typename trainer_type::scalar_type scalar_type; typedef typename in_scalar_vector_type::value_type scalar_type;
typedef typename trainer_type::sample_type sample_type; typedef typename trainer_type::sample_type sample_type;
typedef typename trainer_type::mem_manager_type mem_manager_type; typedef typename trainer_type::mem_manager_type mem_manager_type;
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type; typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
...@@ -239,7 +235,7 @@ namespace dlib ...@@ -239,7 +235,7 @@ namespace dlib
long pos_idx = 0; long pos_idx = 0;
long neg_idx = 0; long neg_idx = 0;
matrix<scalar_type, 1, 2, mem_manager_type> res; matrix<double, 1, 2, mem_manager_type> res;
set_all_elements(res,0); set_all_elements(res,0);
for (long i = 0; i < folds; ++i) for (long i = 0; i < folds; ++i)
...@@ -314,7 +310,7 @@ namespace dlib ...@@ -314,7 +310,7 @@ namespace dlib
} // for (long i = 0; i < folds; ++i) } // for (long i = 0; i < folds; ++i)
return res/(scalar_type)folds; return res/(double)folds;
} }
template < template <
...@@ -322,7 +318,7 @@ namespace dlib ...@@ -322,7 +318,7 @@ namespace dlib
typename in_sample_vector_type, typename in_sample_vector_type,
typename in_scalar_vector_type typename in_scalar_vector_type
> >
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type> const matrix<double, 1, 2, typename trainer_type::mem_manager_type>
cross_validate_trainer ( cross_validate_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const in_sample_vector_type& x, const in_sample_vector_type& x,
......
...@@ -125,8 +125,7 @@ namespace dlib ...@@ -125,8 +125,7 @@ namespace dlib
typename in_sample_vector_type, typename in_sample_vector_type,
typename in_scalar_vector_type typename in_scalar_vector_type
> >
const matrix<typename trainer_type::scalar_type, 1, 2, typename trainer_type::mem_manager_type> const matrix<double,1,2> cross_validate_trainer (
cross_validate_trainer (
const trainer_type& trainer, const trainer_type& trainer,
const in_sample_vector_type& x, const in_sample_vector_type& x,
const in_scalar_vector_type& y, const in_scalar_vector_type& y,
...@@ -157,8 +156,7 @@ namespace dlib ...@@ -157,8 +156,7 @@ namespace dlib
typename in_sample_vector_type, typename in_sample_vector_type,
typename in_scalar_vector_type typename in_scalar_vector_type
> >
const matrix<typename dec_funct_type::scalar_type, 1, 2, typename dec_funct_type::mem_manager_type> const matrix<double,1,2> test_binary_decision_function (
test_binary_decision_function (
const dec_funct_type& dec_funct, const dec_funct_type& dec_funct,
const in_sample_vector_type& x_test, const in_sample_vector_type& x_test,
const in_scalar_vector_type& y_test const in_scalar_vector_type& y_test
......
...@@ -69,7 +69,7 @@ namespace ...@@ -69,7 +69,7 @@ namespace
trainer.set_trainer(krr_trainer<kernel_type>()); trainer.set_trainer(krr_trainer<kernel_type>());
randomize_samples(samples, labels); randomize_samples(samples, labels);
matrix<scalar_type> cv = cross_validate_multiclass_trainer(trainer, samples, labels, 4); matrix<double> cv = cross_validate_multiclass_trainer(trainer, samples, labels, 4);
dlog << LINFO << "confusion matrix: \n" << cv; dlog << LINFO << "confusion matrix: \n" << cv;
const scalar_type cv_accuracy = sum(diag(cv))/sum(cv); const scalar_type cv_accuracy = sum(diag(cv))/sum(cv);
......
...@@ -104,7 +104,7 @@ namespace ...@@ -104,7 +104,7 @@ namespace
print_spinner(); print_spinner();
typedef matrix<scalar_type,2,1> sample_type; typedef matrix<scalar_type,2,1> sample_type;
std::vector<sample_type> samples; std::vector<sample_type> samples, norm_samples;
std::vector<label_type> labels; std::vector<label_type> labels;
// First, get our labeled set of training data // First, get our labeled set of training data
...@@ -129,7 +129,7 @@ namespace ...@@ -129,7 +129,7 @@ namespace
trainer.set_trainer(poly_trainer, 1); trainer.set_trainer(poly_trainer, 1);
randomize_samples(samples, labels); randomize_samples(samples, labels);
matrix<scalar_type> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2); matrix<double> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2);
print_spinner(); print_spinner();
...@@ -140,6 +140,26 @@ namespace ...@@ -140,6 +140,26 @@ namespace
DLIB_TEST_MSG(ans == res, "res: \n" << res); DLIB_TEST_MSG(ans == res, "res: \n" << res);
// test using a normalized_function with a one_vs_all_decision_function
{
poly_trainer.set_kernel(poly_kernel(1.1, 1, 2));
trainer.set_trainer(poly_trainer, 1);
vector_normalizer<sample_type> normalizer;
normalizer.train(samples);
for (unsigned long i = 0; i < samples.size(); ++i)
norm_samples.push_back(normalizer(samples[i]));
normalized_function<one_vs_all_decision_function<ova_trainer> > ndf;
ndf.function = trainer.train(norm_samples, labels);
ndf.normalizer = normalizer;
DLIB_TEST(ndf(samples[0]) == labels[0]);
DLIB_TEST(ndf(samples[40]) == labels[40]);
DLIB_TEST(ndf(samples[90]) == labels[90]);
DLIB_TEST(ndf(samples[120]) == labels[120]);
poly_trainer.set_kernel(poly_kernel(0.1, 1, 2));
trainer.set_trainer(poly_trainer, 1);
print_spinner();
}
one_vs_all_decision_function<ova_trainer> df = trainer.train(samples, labels); one_vs_all_decision_function<ova_trainer> df = trainer.train(samples, labels);
DLIB_TEST(df.number_of_classes() == 3); DLIB_TEST(df.number_of_classes() == 3);
...@@ -205,7 +225,7 @@ namespace ...@@ -205,7 +225,7 @@ namespace
trainer.set_trainer(probabilistic(poly_trainer, 3), 1); trainer.set_trainer(probabilistic(poly_trainer, 3), 1);
randomize_samples(samples, labels); randomize_samples(samples, labels);
matrix<scalar_type> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2); matrix<double> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2);
print_spinner(); print_spinner();
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "tester.h" #include "tester.h"
#include <dlib/svm.h> #include <dlib/svm.h>
#include <dlib/statistics.h>
#include <vector> #include <vector>
#include <sstream> #include <sstream>
...@@ -104,7 +105,7 @@ namespace ...@@ -104,7 +105,7 @@ namespace
print_spinner(); print_spinner();
typedef matrix<scalar_type,2,1> sample_type; typedef matrix<scalar_type,2,1> sample_type;
std::vector<sample_type> samples; std::vector<sample_type> samples, norm_samples;
std::vector<label_type> labels; std::vector<label_type> labels;
// First, get our labeled set of training data // First, get our labeled set of training data
...@@ -129,7 +130,7 @@ namespace ...@@ -129,7 +130,7 @@ namespace
trainer.set_trainer(poly_trainer, 1, 2); trainer.set_trainer(poly_trainer, 1, 2);
randomize_samples(samples, labels); randomize_samples(samples, labels);
matrix<scalar_type> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2); matrix<double> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2);
print_spinner(); print_spinner();
...@@ -140,6 +141,29 @@ namespace ...@@ -140,6 +141,29 @@ namespace
DLIB_TEST_MSG(ans == res, "res: \n" << res); DLIB_TEST_MSG(ans == res, "res: \n" << res);
// test using a normalized_function with a one_vs_one_decision_function
{
poly_trainer.set_kernel(poly_kernel(1.1, 1, 2));
trainer.set_trainer(poly_trainer, 1, 2);
vector_normalizer<sample_type> normalizer;
normalizer.train(samples);
for (unsigned long i = 0; i < samples.size(); ++i)
norm_samples.push_back(normalizer(samples[i]));
normalized_function<one_vs_one_decision_function<ovo_trainer> > ndf;
ndf.function = trainer.train(norm_samples, labels);
ndf.normalizer = normalizer;
DLIB_TEST(ndf(samples[0]) == labels[0]);
DLIB_TEST(ndf(samples[40]) == labels[40]);
DLIB_TEST(ndf(samples[90]) == labels[90]);
DLIB_TEST(ndf(samples[120]) == labels[120]);
poly_trainer.set_kernel(poly_kernel(0.1, 1, 2));
trainer.set_trainer(poly_trainer, 1, 2);
print_spinner();
}
one_vs_one_decision_function<ovo_trainer> df = trainer.train(samples, labels); one_vs_one_decision_function<ovo_trainer> df = trainer.train(samples, labels);
DLIB_TEST(df.number_of_classes() == 3); DLIB_TEST(df.number_of_classes() == 3);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment