Commit f7d97090 authored by Davis King's avatar Davis King

Added the distance_function object

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402387
parent 0c1c9f67
...@@ -231,6 +231,144 @@ namespace dlib ...@@ -231,6 +231,144 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
template <
typename K
>
struct distance_function
{
typedef typename K::scalar_type scalar_type;
typedef typename K::sample_type sample_type;
typedef typename K::mem_manager_type mem_manager_type;
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
const scalar_vector_type alpha;
const scalar_type b;
const K kernel_function;
const sample_vector_type support_vectors;
distance_function (
) : b(0), kernel_function(K()) {}
distance_function (
const distance_function& d
) :
alpha(d.alpha),
b(d.b),
kernel_function(d.kernel_function),
support_vectors(d.support_vectors)
{}
distance_function (
const scalar_vector_type& alpha_,
const scalar_type& b_,
const K& kernel_function_,
const sample_vector_type& support_vectors_
) :
alpha(alpha_),
b(b_),
kernel_function(kernel_function_),
support_vectors(support_vectors_)
{}
distance_function& operator= (
const distance_function& d
)
{
if (this != &d)
{
const_cast<scalar_vector_type&>(alpha) = d.alpha;
const_cast<scalar_type&>(b) = d.b;
const_cast<K&>(kernel_function) = d.kernel_function;
const_cast<sample_vector_type&>(support_vectors) = d.support_vectors;
}
return *this;
}
scalar_type operator() (
const sample_type& x
) const
{
scalar_type temp = 0;
for (long i = 0; i < alpha.nr(); ++i)
temp += alpha(i) * kernel_function(x,support_vectors(i));
temp = b + kernel_function(x,x) - 2*temp;
if (temp > 0)
return std::sqrt(temp);
else
return 0;
}
scalar_type operator() (
const distance_function& x
) const
{
scalar_type temp = 0;
for (long i = 0; i < alpha.nr(); ++i)
for (long j = 0; j < x.alpha.nr(); ++j)
temp += alpha(i)*x.alpha(j) * kernel_function(support_vectors(i), x.support_vectors(j));
temp = b + x.b - 2*temp;
if (temp > 0)
return std::sqrt(temp);
else
return 0;
}
};
template <
typename K
>
void serialize (
const distance_function<K>& item,
std::ostream& out
)
{
try
{
serialize(item.alpha, out);
serialize(item.b, out);
serialize(item.kernel_function, out);
serialize(item.support_vectors, out);
}
catch (serialization_error e)
{
throw serialization_error(e.info + "\n while serializing object of type distance_function");
}
}
template <
typename K
>
void deserialize (
distance_function<K>& item,
std::istream& in
)
{
typedef typename K::scalar_type scalar_type;
typedef typename K::sample_type sample_type;
typedef typename K::mem_manager_type mem_manager_type;
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
try
{
deserialize(const_cast<scalar_vector_type&>(item.alpha), in);
deserialize(const_cast<scalar_type&>(item.b), in);
deserialize(const_cast<K&>(item.kernel_function), in);
deserialize(const_cast<sample_vector_type&>(item.support_vectors), in);
}
catch (serialization_error e)
{
throw serialization_error(e.info + "\n while deserializing object of type distance_function");
}
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
} }
......
...@@ -93,7 +93,7 @@ namespace dlib ...@@ -93,7 +93,7 @@ namespace dlib
for (long i = 0; i < alpha.nr(); ++i) for (long i = 0; i < alpha.nr(); ++i)
temp += alpha(i) * kernel_function(x,support_vectors(i)); temp += alpha(i) * kernel_function(x,support_vectors(i));
returns temp - b; return temp - b;
} }
}; };
...@@ -225,6 +225,141 @@ namespace dlib ...@@ -225,6 +225,141 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <
typename K
>
struct distance_function
{
/*!
REQUIREMENTS ON K
K must be a kernel function object type as defined at the
top of dlib/svm/kernel_abstract.h
WHAT THIS OBJECT REPRESENTS
This object represents a point in kernel induced feature space.
You may use this object to find the distance from the point it
represents to points in input space.
!*/
typedef typename K::scalar_type scalar_type;
typedef typename K::sample_type sample_type;
typedef typename K::mem_manager_type mem_manager_type;
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
const scalar_vector_type alpha;
const scalar_type b;
const K kernel_function;
const sample_vector_type support_vectors;
distance_function (
);
/*!
ensures
- #b == 0
- #alpha.nr() == 0
- #support_vectors.nr() == 0
!*/
distance_function (
const distance_function& f
);
/*!
ensures
- #*this is a copy of f
!*/
distance_function (
const scalar_vector_type& alpha_,
const scalar_type& b_,
const K& kernel_function_,
const sample_vector_type& support_vectors_
) : alpha(alpha_), b(b_), kernel_function(kernel_function_), support_vectors(support_vectors_) {}
/*!
ensures
- populates the decision function with the given support vectors, weights(i.e. alphas),
b term, and kernel function.
!*/
distance_function& operator= (
const distance_function& d
);
/*!
ensures
- #*this is identical to d
- returns *this
!*/
scalar_type operator() (
const sample_type& x
) const
/*!
ensures
- Let O(x) represent the point x projected into kernel induced feature space.
- let c == sum alpha(i)*O(support_vectors(i)) == the point in kernel space that
this object represents.
- Then this object returns the distance between the points O(x) and c in kernel
space.
!*/
{
scalar_type temp = 0;
for (long i = 0; i < alpha.nr(); ++i)
temp += alpha(i) * kernel_function(x,support_vectors(i));
temp = b + kernel_function(x,x) - 2*temp;
if (temp > 0)
return std::sqrt(temp);
else
return 0;
}
scalar_type operator() (
const distance_function& x
) const
/*!
ensures
- returns the distance between the point in kernel space represented by *this and x.
!*/
{
scalar_type temp = 0;
for (long i = 0; i < alpha.nr(); ++i)
for (long j = 0; j < x.alpha.nr(); ++j)
temp += alpha(i)*x.alpha(j) * kernel_function(support_vectors(i), x.support_vectors(j));
temp = b + x.b - 2*temp;
if (temp > 0)
return std::sqrt(temp);
else
return 0;
}
};
template <
typename K
>
void serialize (
const distance_function<K>& item,
std::ostream& out
);
/*!
provides serialization support for distance_function
!*/
template <
typename K
>
void deserialize (
distance_function<K>& item,
std::istream& in
);
/*!
provides serialization support for distance_function
!*/
// ----------------------------------------------------------------------------------------
} }
#endif // DLIB_SVm_FUNCTION_ABSTRACT_ #endif // DLIB_SVm_FUNCTION_ABSTRACT_
......
...@@ -215,6 +215,16 @@ namespace dlib ...@@ -215,6 +215,16 @@ namespace dlib
item.bias_is_stale = true; item.bias_is_stale = true;
} }
distance_function<kernel_type> get_distance_function (
) const
{
refresh_bias();
return distance_function<kernel_type>(vector_to_matrix(alpha),
bias,
kernel,
vector_to_matrix(dictionary));
}
private: private:
void refresh_bias ( void refresh_bias (
......
...@@ -193,6 +193,16 @@ namespace dlib ...@@ -193,6 +193,16 @@ namespace dlib
- returns the number of "support vectors" in the dictionary. - returns the number of "support vectors" in the dictionary.
!*/ !*/
distance_function<kernel_type> get_distance_function (
) const;
/*!
ensures
- returns a distance function F that represents the point learned
by this object so far. I.e. it is the case that:
- for all x: F(x) == (*this)(x)
!*/
}; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment