Commit b9a1e0c9 authored by Davis King

Added the multiclass_linear_decision_function object.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404182
parent 1e7659dc
......@@ -13,6 +13,8 @@
#include "../rand.h"
#include "../statistics.h"
#include "kernel_matrix.h"
#include "kernel.h"
#include "sparse_kernel.h"
namespace dlib
{
......@@ -791,6 +793,107 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
    typename K,
    typename result_type_ = typename K::scalar_type
    >
struct multiclass_linear_decision_function
{
    /*
        A multiclass classifier made of one linear scoring function per class.
        For a sample x, class i receives the score

            dot(rowm(weights,i), x) - b(i)

        and operator() returns labels[i] for the class with the largest score.

        K            - kernel type; must be linear_kernel or sparse_linear_kernel
                       (enforced at compile time below).
        result_type_ - type of the class labels; defaults to K::scalar_type.
    */

    typedef result_type_ result_type;
    typedef K kernel_type;
    typedef typename K::scalar_type scalar_type;
    typedef typename K::sample_type sample_type;
    typedef typename K::mem_manager_type mem_manager_type;

    typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
    typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type;

    // You are getting a compiler error on this line because you supplied a non-linear kernel
    // to the multiclass_linear_decision_function object.  You have to use one of the linear
    // kernels with this object.
    COMPILE_TIME_ASSERT((is_same_type<K, linear_kernel<sample_type> >::value ||
                         is_same_type<K, sparse_linear_kernel<sample_type> >::value ));

    // Row i is the weight vector used to score the class whose label is labels[i].
    scalar_matrix_type weights;
    // b(i) is the offset subtracted from the score of class i.
    scalar_vector_type b;
    // labels[i] is the value returned when class i has the largest score.
    std::vector<result_type> labels;

    // Returns every label this object can predict.
    const std::vector<result_type>& get_labels(
    ) const { return labels; }

    // Returns the number of labels/classes, i.e. get_labels().size().
    unsigned long number_of_classes (
    ) const { return static_cast<unsigned long>(labels.size()); }

    // Returns the label of the highest scoring class for the sample x.
    //
    // Preconditions (not checked here): labels, b and the rows of weights all
    // describe the same, non-empty set of classes, since row 0 / b(0) /
    // labels[0] are accessed unconditionally.
    //
    // If several classes share the maximum score the one with the lowest
    // index is returned, because only a strictly larger score replaces the
    // current best.
    result_type operator() (
        const sample_type& x
    ) const
    {
        // Rather than doing something like, best_idx = index_of_max(weights*x-b)
        // we do the following somewhat more complex thing because this supports
        // both sparse and dense samples.
        using sparse_vector::dot;

        // Class 0 is the initial incumbent.
        scalar_type best_val = dot(rowm(weights,0),x) - b(0);
        unsigned long best_idx = 0;

        // Start at class 1: re-scoring class 0 could never beat itself under
        // the strict comparison below, so it would only be a wasted dot product.
        const unsigned long num_classes = number_of_classes();
        for (unsigned long i = 1; i < num_classes; ++i)
        {
            const scalar_type temp = dot(rowm(weights,i),x) - b(i);
            if (temp > best_val)
            {
                best_val = temp;
                best_idx = i;
            }
        }

        return labels[best_idx];
    }
};
template <
typename K,
typename result_type_
>
void serialize (
const multiclass_linear_decision_function<K,result_type_>& item,
std::ostream& out
)
{
try
{
serialize(item.weights, out);
serialize(item.b, out);
serialize(item.labels, out);
}
catch (serialization_error& e)
{
throw serialization_error(e.info + "\n while serializing object of type multiclass_linear_decision_function");
}
}
// Reads a multiclass_linear_decision_function from the stream `in` into `item`.
//
// The members are read as weights, then b, then labels, which is the order
// serialize() writes them in.  A serialization_error raised while reading a
// member is rethrown as a new serialization_error whose message has the name
// of this type appended.
template <
    typename K,
    typename result_type_
    >
void deserialize (
    multiclass_linear_decision_function<K,result_type_>& item,
    std::istream& in
)
{
    try
    {
        deserialize(item.weights, in);
        deserialize(item.b, in);
        deserialize(item.labels, in);
    }
    catch (const serialization_error& cause)
    {
        // Keep the inner message and add which enclosing object was being read.
        const char* const context =
            "\n while deserializing object of type multiclass_linear_decision_function";
        throw serialization_error(cause.info + context);
    }
}
// ----------------------------------------------------------------------------------------
}
......
......@@ -862,6 +862,100 @@ namespace dlib
provides serialization support for projection_function
!*/
// ----------------------------------------------------------------------------------------
template <
typename K,
typename result_type_ = typename K::scalar_type
>
struct multiclass_linear_decision_function
{
/*!
REQUIREMENTS ON K
K must be either linear_kernel or sparse_linear_kernel.
REQUIREMENTS ON result_type_
result_type_ is the type used for the class labels.  If it is not
supplied it defaults to K::scalar_type.
WHAT THIS OBJECT REPRESENTS
This object represents a multiclass classifier built out of a set of
binary classifiers. Each binary classifier is used to vote for the
correct multiclass label using a one vs. all strategy. Therefore,
if you have N classes then there will be N binary classifiers inside
this object. Additionally, this object is linear in the sense that
each of these binary classifiers is a simple linear plane.
!*/
typedef result_type_ result_type;
typedef K kernel_type;
typedef typename K::scalar_type scalar_type;
typedef typename K::sample_type sample_type;
typedef typename K::mem_manager_type mem_manager_type;
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type;
scalar_matrix_type weights;
/*!
One row per class.  Row i holds the weights of the linear classifier
that votes for labels[i].
!*/
scalar_vector_type b;
/*!
One element per class.  b(i) is the offset subtracted from the output
of the i-th linear classifier, i.e. the score of class i for a sample
x is element i of weights*x-b.
!*/
std::vector<result_type> labels;
/*!
labels[i] is the label returned by operator() when the i-th linear
classifier produces the largest score.
!*/
const std::vector<result_type>& get_labels(
) const { return labels; }
/*!
ensures
- returns a vector containing all the labels which can be
predicted by this object.
!*/
unsigned long number_of_classes (
) const;
/*!
ensures
- returns get_labels().size()
(i.e. returns the number of different labels/classes predicted by
this object)
!*/
result_type operator() (
const sample_type& x
) const;
/*!
requires
- weights.size() > 0
- weights.nr() == number_of_classes() == b.size()
- if (x is a dense vector, i.e. a dlib::matrix) then
- is_vector(x) == true
- x.size() == weights.nc()
(i.e. it must be legal to multiply weights with x)
ensures
- Returns the predicted label for the x sample. In particular, it returns
the following:
labels[index_of_max(weights*x-b)]
!*/
};
template <
typename K,
typename result_type_
>
void serialize (
const multiclass_linear_decision_function<K,result_type_>& item,
std::ostream& out
);
/*!
provides serialization support for multiclass_linear_decision_function
ensures
- writes item.weights, item.b, and item.labels to out, in that order
throws
- serialization_error
If writing any of the members fails with a serialization_error then a
serialization_error is thrown whose message additionally names
multiclass_linear_decision_function as the object being serialized.
!*/
template <
typename K,
typename result_type_
>
void deserialize (
multiclass_linear_decision_function<K,result_type_>& item,
std::istream& in
);
/*!
provides deserialization support for multiclass_linear_decision_function
ensures
- reads item.weights, item.b, and item.labels from in, in that order
throws
- serialization_error
If reading any of the members fails with a serialization_error then a
serialization_error is thrown whose message additionally names
multiclass_linear_decision_function as the object being deserialized.
!*/
// ----------------------------------------------------------------------------------------
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment