Commit 543e289f authored by Davis King's avatar Davis King

Improved the distance_function object by turning it into a properly encapsulated class
rather than just a simple struct.  I also added overloaded +, -, *, and / operators
for this object so you can do the kind of arithmetic you would expect on an object
that represents a point in a vector space.  Note that this breaks backwards compatibility
with the previous interface, because the member variables are now private.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%404124
parent f25b5369
......@@ -339,8 +339,9 @@ namespace dlib
template <
typename K
>
struct distance_function
class distance_function
{
public:
typedef K kernel_type;
typedef typename K::scalar_type scalar_type;
typedef typename K::sample_type sample_type;
......@@ -349,14 +350,43 @@ namespace dlib
typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
scalar_vector_type alpha;
scalar_type b;
K kernel_function;
sample_vector_type basis_vectors;
// Default constructor.  Sets the stored squared norm b to 0 and uses a
// default-constructed kernel; alpha and basis_vectors are left
// default-constructed.
distance_function (
) :
    b(0),
    kernel_function(K())
{
}
// Constructs a distance_function that uses the given kernel.  The stored
// squared norm b is set to 0; alpha and basis_vectors are left
// default-constructed.  Marked explicit so a kernel object is never
// silently converted into a distance_function.
explicit distance_function (
    const kernel_type& kern
) :
    b(0),
    kernel_function(kern)
{
}
// Constructs a distance_function built from a single sample.
//   kern - kernel to store in this object
//   samp - the sample that becomes the only basis vector
// alpha is initialized from ones_matrix<scalar_type>(1,1) and the stored
// squared norm b is kern(samp,samp).
distance_function (
    const kernel_type& kern,
    const sample_type& samp
) :
    alpha(ones_matrix<scalar_type>(1,1)),
    b(kern(samp,samp)),
    kernel_function(kern)
{
    // basis_vectors is not in the initializer list, so size it here and
    // store samp as its one element.
    basis_vectors.set_size(1,1);
    basis_vectors(0) = samp;
}
// Constructs a distance_function from a decision_function.
// alpha, the kernel and the basis vectors are copied from f, and the stored
// squared norm b is computed as
//     trans(f.alpha) * kernel_matrix(f.kernel_function, f.basis_vectors) * f.alpha
// Requires f.alpha.size() == f.basis_vectors.size().
distance_function (
    const decision_function<K>& f
) :
    alpha(f.alpha),
    b(trans(f.alpha)*kernel_matrix(f.kernel_function,f.basis_vectors)*f.alpha),
    kernel_function(f.kernel_function),
    basis_vectors(f.basis_vectors)
{
    // make sure requires clause is not broken
    // NOTE(review): this check sits in the constructor body, so it runs
    // after the member initializers above have already used f.
    DLIB_ASSERT(f.alpha.size() == f.basis_vectors.size(),
        "\t distance_function(f)"
        << "\n\t The supplied decision_function is invalid."
        << "\n\t f.alpha.size(): " << f.alpha.size()
        << "\n\t f.basis_vectors.size(): " << f.basis_vectors.size()
        );
}
distance_function (
const distance_function& d
) :
......@@ -364,7 +394,8 @@ namespace dlib
b(d.b),
kernel_function(d.kernel_function),
basis_vectors(d.basis_vectors)
{}
{
}
distance_function (
const scalar_vector_type& alpha_,
......@@ -376,7 +407,46 @@ namespace dlib
b(b_),
kernel_function(kernel_function_),
basis_vectors(basis_vectors_)
{}
{
// make sure requires clause is not broken
DLIB_ASSERT(alpha_.size() == basis_vectors_.size(),
"\t distance_function()"
<< "\n\t The supplied arguments are invalid."
<< "\n\t alpha_.size(): " << alpha_.size()
<< "\n\t basis_vectors_.size(): " << basis_vectors_.size()
);
}
// Constructs a distance_function from explicit weights and basis vectors,
// computing the squared norm instead of taking it as an argument.
//   alpha_           - weight of each basis vector
//   kernel_function_ - kernel to store in this object
//   basis_vectors_   - the basis vectors
// The stored squared norm b is
//     trans(alpha_) * kernel_matrix(kernel_function_, basis_vectors_) * alpha_
// Requires alpha_.size() == basis_vectors_.size().
distance_function (
    const scalar_vector_type& alpha_,
    const K& kernel_function_,
    const sample_vector_type& basis_vectors_
) :
    alpha(alpha_),
    // Read the constructor argument alpha_ rather than the member alpha, so
    // this initializer does not depend on alpha being declared (and thus
    // initialized) before b.  The value is the same since alpha is a copy
    // of alpha_.
    b(trans(alpha_)*kernel_matrix(kernel_function_,basis_vectors_)*alpha_),
    kernel_function(kernel_function_),
    basis_vectors(basis_vectors_)
{
    // make sure requires clause is not broken
    DLIB_ASSERT(alpha_.size() == basis_vectors_.size(),
        "\t distance_function()"
        << "\n\t The supplied arguments are invalid."
        << "\n\t alpha_.size(): " << alpha_.size()
        << "\n\t basis_vectors_.size(): " << basis_vectors_.size()
        );
}
// Returns a read-only reference to the stored alpha vector.
const scalar_vector_type& get_alpha (
) const
{
    return alpha;
}
// Returns a read-only reference to the stored squared norm (member b).
const scalar_type& get_squared_norm (
) const
{
    return b;
}
// Returns a read-only reference to the kernel stored in this object.
const K& get_kernel(
) const
{
    return kernel_function;
}
// Returns a read-only reference to the stored basis vectors.
const sample_vector_type& get_basis_vectors (
) const
{
    return basis_vectors;
}
distance_function& operator= (
const distance_function& d
......@@ -422,8 +492,76 @@ namespace dlib
else
return 0;
}
// Returns a copy of this distance_function scaled by val: every weight in
// alpha is multiplied by val and the squared norm by val*val.  The kernel
// and basis vectors are carried over unchanged.
distance_function operator* (
    const scalar_type& val
) const
{
    const scalar_vector_type scaled_alpha(val*alpha);
    const scalar_type scaled_squared_norm = val*val*b;
    return distance_function(scaled_alpha,
                             scaled_squared_norm,
                             kernel_function,
                             basis_vectors);
}
// Returns a copy of this distance_function divided by val: every weight in
// alpha is divided by val and the squared norm is divided by val twice.
// The kernel and basis vectors are carried over unchanged.
// NOTE(review): val is not checked against zero here.
distance_function operator/ (
    const scalar_type& val
) const
{
    const scalar_vector_type scaled_alpha(alpha/val);
    const scalar_type scaled_squared_norm = b/val/val;
    return distance_function(scaled_alpha,
                             scaled_squared_norm,
                             kernel_function,
                             basis_vectors);
}
// Returns the sum of this distance_function and rhs.
// The result stacks both weight vectors and both sets of basis vectors
// (join_cols), and its squared norm is
//     b + rhs.b + 2*trans(alpha)*K*rhs.alpha
// where K = kernel_matrix(kernel_function, basis_vectors, rhs.basis_vectors).
// Requires get_kernel() == rhs.get_kernel().
distance_function operator+ (
const distance_function& rhs
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(get_kernel() == rhs.get_kernel(),
"\t distance_function distance_function::operator+()"
<< "\n\t You can only add two distance_functions together if they use the same kernel."
);
// The result keeps this object's kernel_function.
return distance_function(join_cols(alpha, rhs.alpha),
b + rhs.b + 2*trans(alpha)*kernel_matrix(kernel_function,basis_vectors,rhs.basis_vectors)*rhs.alpha,
kernel_function,
join_cols(basis_vectors, rhs.basis_vectors));
}
// Returns the difference of this distance_function and rhs.
// The result stacks this object's weights with the negated weights of rhs,
// stacks both sets of basis vectors (join_cols), and its squared norm is
//     b + rhs.b - 2*trans(alpha)*K*rhs.alpha
// where K = kernel_matrix(kernel_function, basis_vectors, rhs.basis_vectors).
// Requires get_kernel() == rhs.get_kernel().
distance_function operator- (
const distance_function& rhs
) const
{
// make sure requires clause is not broken
DLIB_ASSERT(get_kernel() == rhs.get_kernel(),
"\t distance_function distance_function::operator-()"
<< "\n\t You can only subtract two distance_functions if they use the same kernel."
);
// rhs.alpha is negated so the stacked weights represent (this - rhs);
// the result keeps this object's kernel_function.
return distance_function(join_cols(alpha, -rhs.alpha),
b + rhs.b - 2*trans(alpha)*kernel_matrix(kernel_function,basis_vectors,rhs.basis_vectors)*rhs.alpha,
kernel_function,
join_cols(basis_vectors, rhs.basis_vectors));
}
private:
scalar_vector_type alpha;
scalar_type b;
K kernel_function;
sample_vector_type basis_vectors;
};
// Scalar-on-the-left multiplication for distance_function.
// Forwards to the member operator*, so val*df yields the same result
// as df*val.
template <typename K>
distance_function<K> operator* (
    const typename K::scalar_type& val,
    const distance_function<K>& df
)
{
    return df*val;
}
template <
typename K
>
......
This diff is collapsed.
......@@ -247,19 +247,22 @@ namespace dlib
distance_function<kernel_type> get_distance_function (
) const
{
distance_function<kernel_type> temp;
temp.kernel_function = kernel;
if (samples_seen > 0)
{
temp.b = squared_norm();
temp.basis_vectors.set_size(1);
temp.basis_vectors(0) = w;
temp.alpha.set_size(1);
temp.alpha(0) = alpha;
}
typename distance_function<kernel_type>::sample_vector_type temp_basis_vectors;
typename distance_function<kernel_type>::scalar_vector_type temp_alpha;
temp_basis_vectors.set_size(1);
temp_basis_vectors(0) = w;
temp_alpha.set_size(1);
temp_alpha(0) = alpha;
return temp;
return distance_function<kernel_type>(temp_alpha, squared_norm(), kernel, temp_basis_vectors);
}
else
{
return distance_function<kernel_type>(kernel);
}
}
private:
......@@ -576,12 +579,11 @@ namespace dlib
distance_function<kernel_type> get_distance_function (
) const
{
distance_function<kernel_type> temp;
temp.kernel_function = kernel;
if (samples_seen > 0)
{
temp.b = squared_norm();
typename distance_function<kernel_type>::sample_vector_type temp_basis_vectors;
typename distance_function<kernel_type>::scalar_vector_type temp_alpha;
// What we are doing here needs a bit of explanation. The w vector
// has an implicit extra dimension tacked on to it with the value of w_extra.
......@@ -595,27 +597,30 @@ namespace dlib
if (std::abs(w_extra) > std::numeric_limits<scalar_type>::epsilon())
{
scale = (x_extra/w_extra);
temp.basis_vectors.set_size(1);
temp.alpha.set_size(1);
temp.basis_vectors(0) = w*scale;
temp.alpha(0) = alpha/scale;
temp_basis_vectors.set_size(1);
temp_alpha.set_size(1);
temp_basis_vectors(0) = w*scale;
temp_alpha(0) = alpha/scale;
}
else
{
// In this case w_extra is zero. So the only way we can get the same
// thing in the output basis vector set is by using two vectors
temp.basis_vectors.set_size(2);
temp.alpha.set_size(2);
temp.basis_vectors(0) = 2*w;
temp.alpha(0) = alpha;
temp.basis_vectors(1) = w;
temp.alpha(1) = -alpha;
temp_basis_vectors.set_size(2);
temp_alpha.set_size(2);
temp_basis_vectors(0) = 2*w;
temp_alpha(0) = alpha;
temp_basis_vectors(1) = w;
temp_alpha(1) = -alpha;
}
return distance_function<kernel_type>(temp_alpha, squared_norm(), kernel, temp_basis_vectors);
}
else
{
return distance_function<kernel_type>(kernel);
}
return temp;
}
private:
......@@ -877,19 +882,22 @@ namespace dlib
distance_function<kernel_type> get_distance_function (
) const
{
distance_function<kernel_type> temp;
temp.kernel_function = kernel;
if (samples_seen > 0)
{
temp.b = squared_norm();
temp.basis_vectors.set_size(1);
temp.basis_vectors(0) = sample_type(w.begin(), w.end());
temp.alpha.set_size(1);
temp.alpha(0) = alpha;
}
typename distance_function<kernel_type>::sample_vector_type temp_basis_vectors;
typename distance_function<kernel_type>::scalar_vector_type temp_alpha;
temp_basis_vectors.set_size(1);
temp_basis_vectors(0) = sample_type(w.begin(), w.end());
temp_alpha.set_size(1);
temp_alpha(0) = alpha;
return temp;
return distance_function<kernel_type>(temp_alpha, squared_norm(), kernel, temp_basis_vectors);
}
else
{
return distance_function<kernel_type>(kernel);
}
}
private:
......@@ -1201,12 +1209,10 @@ namespace dlib
distance_function<kernel_type> get_distance_function (
) const
{
distance_function<kernel_type> temp;
temp.kernel_function = kernel;
if (samples_seen > 0)
{
temp.b = squared_norm();
typename distance_function<kernel_type>::sample_vector_type temp_basis_vectors;
typename distance_function<kernel_type>::scalar_vector_type temp_alpha;
// What we are doing here needs a bit of explanation. The w vector
// has an implicit extra dimension tacked on to it with the value of w_extra.
......@@ -1220,29 +1226,33 @@ namespace dlib
if (std::abs(w_extra) > std::numeric_limits<scalar_type>::epsilon())
{
scale = (x_extra/w_extra);
temp.basis_vectors.set_size(1);
temp.alpha.set_size(1);
temp.basis_vectors(0) = sample_type(w.begin(), w.end());
sparse_vector::scale_by(temp.basis_vectors(0), scale);
temp.alpha(0) = alpha/scale;
temp_basis_vectors.set_size(1);
temp_alpha.set_size(1);
temp_basis_vectors(0) = sample_type(w.begin(), w.end());
sparse_vector::scale_by(temp_basis_vectors(0), scale);
temp_alpha(0) = alpha/scale;
}
else
{
// In this case w_extra is zero. So the only way we can get the same
// thing in the output basis vector set is by using two vectors
temp.basis_vectors.set_size(2);
temp.alpha.set_size(2);
temp.basis_vectors(0) = sample_type(w.begin(), w.end());
sparse_vector::scale_by(temp.basis_vectors(0), 2);
temp.alpha(0) = alpha;
temp.basis_vectors(1) = sample_type(w.begin(), w.end());
temp.alpha(1) = -alpha;
temp_basis_vectors.set_size(2);
temp_alpha.set_size(2);
temp_basis_vectors(0) = sample_type(w.begin(), w.end());
sparse_vector::scale_by(temp_basis_vectors(0), 2);
temp_alpha(0) = alpha;
temp_basis_vectors(1) = sample_type(w.begin(), w.end());
temp_alpha(1) = -alpha;
}
return distance_function<kernel_type>(temp_alpha, squared_norm(), kernel, temp_basis_vectors);
}
else
{
return distance_function<kernel_type>(kernel);
}
return temp;
}
private:
......
......@@ -248,7 +248,7 @@ namespace dlib
) const
{
distance_function<offset_kernel<kernel_type> > df = w.get_distance_function();
return decision_function<kernel_type>(df.alpha, -tau*sum(df.alpha), kernel, df.basis_vectors);
return decision_function<kernel_type>(df.get_alpha(), -tau*sum(df.get_alpha()), kernel, df.get_basis_vectors());
}
void swap (
......
......@@ -122,7 +122,7 @@ namespace
// projected onto
DLIB_TEST_MSG(abs(df(test_point) - err) < 1e-10, abs(df(test_point) - err));
// while we are at it make sure the squared norm in the distance function is right
double df_error = abs(df.b - trans(df.alpha)*kernel_matrix(kern, samples)*df.alpha);
double df_error = abs(df.get_squared_norm() - trans(df.get_alpha())*kernel_matrix(kern, samples)*df.get_alpha());
DLIB_TEST_MSG( df_error < 1e-10, df_error);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment