Commit 543e289f authored by Davis King's avatar Davis King

Improved the distance_function object by turning it into a properly encapsulated class

rather than just a simple struct.  I also added overloaded +, -, *, and / operators
for this object so you can do the kind of arithmetic you would expect on an object
which represents a point in a vector space.  This breaks backwards compatibility
with the previous interface though as the member variables are now private.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404124
parent f25b5369
...@@ -339,8 +339,9 @@ namespace dlib ...@@ -339,8 +339,9 @@ namespace dlib
template < template <
typename K typename K
> >
struct distance_function class distance_function
{ {
public:
typedef K kernel_type; typedef K kernel_type;
typedef typename K::scalar_type scalar_type; typedef typename K::scalar_type scalar_type;
typedef typename K::sample_type sample_type; typedef typename K::sample_type sample_type;
...@@ -349,14 +350,43 @@ namespace dlib ...@@ -349,14 +350,43 @@ namespace dlib
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type; typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type; typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
scalar_vector_type alpha;
scalar_type b;
K kernel_function;
sample_vector_type basis_vectors;
distance_function ( distance_function (
) : b(0), kernel_function(K()) {} ) : b(0), kernel_function(K()) {}
explicit distance_function (
const kernel_type& kern
) : b(0), kernel_function(kern) {}
distance_function (
const kernel_type& kern,
const sample_type& samp
) :
alpha(ones_matrix<scalar_type>(1,1)),
b(kern(samp,samp)),
kernel_function(kern)
{
basis_vectors.set_size(1,1);
basis_vectors(0) = samp;
}
distance_function (
const decision_function<K>& f
) :
alpha(f.alpha),
b(trans(f.alpha)*kernel_matrix(f.kernel_function,f.basis_vectors)*f.alpha),
kernel_function(f.kernel_function),
basis_vectors(f.basis_vectors)
{
// make sure requires clause is not broken
DLIB_ASSERT(f.alpha.size() == f.basis_vectors.size(),
"\t distance_function(f)"
<< "\n\t The supplied decision_function is invalid."
<< "\n\t f.alpha.size(): " << f.alpha.size()
<< "\n\t f.basis_vectors.size(): " << f.basis_vectors.size()
);
}
distance_function ( distance_function (
const distance_function& d const distance_function& d
) : ) :
...@@ -364,7 +394,8 @@ namespace dlib ...@@ -364,7 +394,8 @@ namespace dlib
b(d.b), b(d.b),
kernel_function(d.kernel_function), kernel_function(d.kernel_function),
basis_vectors(d.basis_vectors) basis_vectors(d.basis_vectors)
{} {
}
distance_function ( distance_function (
const scalar_vector_type& alpha_, const scalar_vector_type& alpha_,
...@@ -376,7 +407,46 @@ namespace dlib ...@@ -376,7 +407,46 @@ namespace dlib
b(b_), b(b_),
kernel_function(kernel_function_), kernel_function(kernel_function_),
basis_vectors(basis_vectors_) basis_vectors(basis_vectors_)
{} {
// make sure requires clause is not broken
DLIB_ASSERT(alpha_.size() == basis_vectors_.size(),
"\t distance_function()"
<< "\n\t The supplied arguments are invalid."
<< "\n\t alpha_.size(): " << alpha_.size()
<< "\n\t basis_vectors_.size(): " << basis_vectors_.size()
);
}
distance_function (
const scalar_vector_type& alpha_,
const K& kernel_function_,
const sample_vector_type& basis_vectors_
) :
alpha(alpha_),
b(trans(alpha)*kernel_matrix(kernel_function_,basis_vectors_)*alpha),
kernel_function(kernel_function_),
basis_vectors(basis_vectors_)
{
// make sure requires clause is not broken
DLIB_ASSERT(alpha_.size() == basis_vectors_.size(),
"\t distance_function()"
<< "\n\t The supplied arguments are invalid."
<< "\n\t alpha_.size(): " << alpha_.size()
<< "\n\t basis_vectors_.size(): " << basis_vectors_.size()
);
}
const scalar_vector_type& get_alpha (
) const { return alpha; }
const scalar_type& get_squared_norm (
) const { return b; }
const K& get_kernel(
) const { return kernel_function; }
const sample_vector_type& get_basis_vectors (
) const { return basis_vectors; }
distance_function& operator= ( distance_function& operator= (
const distance_function& d const distance_function& d
...@@ -422,8 +492,76 @@ namespace dlib ...@@ -422,8 +492,76 @@ namespace dlib
else else
return 0; return 0;
} }
distance_function operator* (
const scalar_type& val
) const
{
return distance_function(val*alpha,
val*val*b,
kernel_function,
basis_vectors);
}
distance_function operator/ (
const scalar_type& val
) const
{
return distance_function(alpha/val,
b/val/val,
kernel_function,
basis_vectors);
}
distance_function operator+ (
const distance_function& rhs
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(get_kernel() == rhs.get_kernel(),
"\t distance_function distance_function::operator+()"
<< "\n\t You can only add two distance_functions together if they use the same kernel."
);
return distance_function(join_cols(alpha, rhs.alpha),
b + rhs.b + 2*trans(alpha)*kernel_matrix(kernel_function,basis_vectors,rhs.basis_vectors)*rhs.alpha,
kernel_function,
join_cols(basis_vectors, rhs.basis_vectors));
}
distance_function operator- (
const distance_function& rhs
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(get_kernel() == rhs.get_kernel(),
"\t distance_function distance_function::operator-()"
<< "\n\t You can only subtract two distance_functions if they use the same kernel."
);
return distance_function(join_cols(alpha, -rhs.alpha),
b + rhs.b - 2*trans(alpha)*kernel_matrix(kernel_function,basis_vectors,rhs.basis_vectors)*rhs.alpha,
kernel_function,
join_cols(basis_vectors, rhs.basis_vectors));
}
private:
scalar_vector_type alpha;
scalar_type b;
K kernel_function;
sample_vector_type basis_vectors;
}; };
template <
typename K
>
distance_function<K> operator* (
const typename K::scalar_type& val,
const distance_function<K>& df
) { return df*val; }
template < template <
typename K typename K
> >
......
This diff is collapsed.
...@@ -247,19 +247,22 @@ namespace dlib ...@@ -247,19 +247,22 @@ namespace dlib
distance_function<kernel_type> get_distance_function ( distance_function<kernel_type> get_distance_function (
) const ) const
{ {
distance_function<kernel_type> temp;
temp.kernel_function = kernel;
if (samples_seen > 0) if (samples_seen > 0)
{ {
temp.b = squared_norm(); typename distance_function<kernel_type>::sample_vector_type temp_basis_vectors;
temp.basis_vectors.set_size(1); typename distance_function<kernel_type>::scalar_vector_type temp_alpha;
temp.basis_vectors(0) = w;
temp.alpha.set_size(1); temp_basis_vectors.set_size(1);
temp.alpha(0) = alpha; temp_basis_vectors(0) = w;
} temp_alpha.set_size(1);
temp_alpha(0) = alpha;
return temp; return distance_function<kernel_type>(temp_alpha, squared_norm(), kernel, temp_basis_vectors);
}
else
{
return distance_function<kernel_type>(kernel);
}
} }
private: private:
...@@ -576,12 +579,11 @@ namespace dlib ...@@ -576,12 +579,11 @@ namespace dlib
distance_function<kernel_type> get_distance_function ( distance_function<kernel_type> get_distance_function (
) const ) const
{ {
distance_function<kernel_type> temp;
temp.kernel_function = kernel;
if (samples_seen > 0) if (samples_seen > 0)
{ {
temp.b = squared_norm(); typename distance_function<kernel_type>::sample_vector_type temp_basis_vectors;
typename distance_function<kernel_type>::scalar_vector_type temp_alpha;
// What we are doing here needs a bit of explanation. The w vector // What we are doing here needs a bit of explanation. The w vector
// has an implicit extra dimension tacked on to it with the value of w_extra. // has an implicit extra dimension tacked on to it with the value of w_extra.
...@@ -595,27 +597,30 @@ namespace dlib ...@@ -595,27 +597,30 @@ namespace dlib
if (std::abs(w_extra) > std::numeric_limits<scalar_type>::epsilon()) if (std::abs(w_extra) > std::numeric_limits<scalar_type>::epsilon())
{ {
scale = (x_extra/w_extra); scale = (x_extra/w_extra);
temp.basis_vectors.set_size(1); temp_basis_vectors.set_size(1);
temp.alpha.set_size(1); temp_alpha.set_size(1);
temp.basis_vectors(0) = w*scale; temp_basis_vectors(0) = w*scale;
temp.alpha(0) = alpha/scale; temp_alpha(0) = alpha/scale;
} }
else else
{ {
// In this case w_extra is zero. So the only way we can get the same // In this case w_extra is zero. So the only way we can get the same
// thing in the output basis vector set is by using two vectors // thing in the output basis vector set is by using two vectors
temp.basis_vectors.set_size(2); temp_basis_vectors.set_size(2);
temp.alpha.set_size(2); temp_alpha.set_size(2);
temp.basis_vectors(0) = 2*w; temp_basis_vectors(0) = 2*w;
temp.alpha(0) = alpha; temp_alpha(0) = alpha;
temp.basis_vectors(1) = w; temp_basis_vectors(1) = w;
temp.alpha(1) = -alpha; temp_alpha(1) = -alpha;
} }
return distance_function<kernel_type>(temp_alpha, squared_norm(), kernel, temp_basis_vectors);
}
else
{
return distance_function<kernel_type>(kernel);
} }
return temp;
} }
private: private:
...@@ -877,19 +882,22 @@ namespace dlib ...@@ -877,19 +882,22 @@ namespace dlib
distance_function<kernel_type> get_distance_function ( distance_function<kernel_type> get_distance_function (
) const ) const
{ {
distance_function<kernel_type> temp;
temp.kernel_function = kernel;
if (samples_seen > 0) if (samples_seen > 0)
{ {
temp.b = squared_norm(); typename distance_function<kernel_type>::sample_vector_type temp_basis_vectors;
temp.basis_vectors.set_size(1); typename distance_function<kernel_type>::scalar_vector_type temp_alpha;
temp.basis_vectors(0) = sample_type(w.begin(), w.end());
temp.alpha.set_size(1); temp_basis_vectors.set_size(1);
temp.alpha(0) = alpha; temp_basis_vectors(0) = sample_type(w.begin(), w.end());
} temp_alpha.set_size(1);
temp_alpha(0) = alpha;
return temp; return distance_function<kernel_type>(temp_alpha, squared_norm(), kernel, temp_basis_vectors);
}
else
{
return distance_function<kernel_type>(kernel);
}
} }
private: private:
...@@ -1201,12 +1209,10 @@ namespace dlib ...@@ -1201,12 +1209,10 @@ namespace dlib
distance_function<kernel_type> get_distance_function ( distance_function<kernel_type> get_distance_function (
) const ) const
{ {
distance_function<kernel_type> temp;
temp.kernel_function = kernel;
if (samples_seen > 0) if (samples_seen > 0)
{ {
temp.b = squared_norm(); typename distance_function<kernel_type>::sample_vector_type temp_basis_vectors;
typename distance_function<kernel_type>::scalar_vector_type temp_alpha;
// What we are doing here needs a bit of explanation. The w vector // What we are doing here needs a bit of explanation. The w vector
// has an implicit extra dimension tacked on to it with the value of w_extra. // has an implicit extra dimension tacked on to it with the value of w_extra.
...@@ -1220,29 +1226,33 @@ namespace dlib ...@@ -1220,29 +1226,33 @@ namespace dlib
if (std::abs(w_extra) > std::numeric_limits<scalar_type>::epsilon()) if (std::abs(w_extra) > std::numeric_limits<scalar_type>::epsilon())
{ {
scale = (x_extra/w_extra); scale = (x_extra/w_extra);
temp.basis_vectors.set_size(1); temp_basis_vectors.set_size(1);
temp.alpha.set_size(1); temp_alpha.set_size(1);
temp.basis_vectors(0) = sample_type(w.begin(), w.end()); temp_basis_vectors(0) = sample_type(w.begin(), w.end());
sparse_vector::scale_by(temp.basis_vectors(0), scale); sparse_vector::scale_by(temp_basis_vectors(0), scale);
temp.alpha(0) = alpha/scale; temp_alpha(0) = alpha/scale;
} }
else else
{ {
// In this case w_extra is zero. So the only way we can get the same // In this case w_extra is zero. So the only way we can get the same
// thing in the output basis vector set is by using two vectors // thing in the output basis vector set is by using two vectors
temp.basis_vectors.set_size(2); temp_basis_vectors.set_size(2);
temp.alpha.set_size(2); temp_alpha.set_size(2);
temp.basis_vectors(0) = sample_type(w.begin(), w.end()); temp_basis_vectors(0) = sample_type(w.begin(), w.end());
sparse_vector::scale_by(temp.basis_vectors(0), 2); sparse_vector::scale_by(temp_basis_vectors(0), 2);
temp.alpha(0) = alpha; temp_alpha(0) = alpha;
temp.basis_vectors(1) = sample_type(w.begin(), w.end()); temp_basis_vectors(1) = sample_type(w.begin(), w.end());
temp.alpha(1) = -alpha; temp_alpha(1) = -alpha;
} }
return distance_function<kernel_type>(temp_alpha, squared_norm(), kernel, temp_basis_vectors);
} }
else
{
return distance_function<kernel_type>(kernel);
}
return temp;
} }
private: private:
......
...@@ -248,7 +248,7 @@ namespace dlib ...@@ -248,7 +248,7 @@ namespace dlib
) const ) const
{ {
distance_function<offset_kernel<kernel_type> > df = w.get_distance_function(); distance_function<offset_kernel<kernel_type> > df = w.get_distance_function();
return decision_function<kernel_type>(df.alpha, -tau*sum(df.alpha), kernel, df.basis_vectors); return decision_function<kernel_type>(df.get_alpha(), -tau*sum(df.get_alpha()), kernel, df.get_basis_vectors());
} }
void swap ( void swap (
......
...@@ -122,7 +122,7 @@ namespace ...@@ -122,7 +122,7 @@ namespace
// projected onto // projected onto
DLIB_TEST_MSG(abs(df(test_point) - err) < 1e-10, abs(df(test_point) - err)); DLIB_TEST_MSG(abs(df(test_point) - err) < 1e-10, abs(df(test_point) - err));
// while we are at it make sure the squared norm in the distance function is right // while we are at it make sure the squared norm in the distance function is right
double df_error = abs(df.b - trans(df.alpha)*kernel_matrix(kern, samples)*df.alpha); double df_error = abs(df.get_squared_norm() - trans(df.get_alpha())*kernel_matrix(kern, samples)*df.get_alpha());
DLIB_TEST_MSG( df_error < 1e-10, df_error); DLIB_TEST_MSG( df_error < 1e-10, df_error);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment