Commit 27f9e6ef authored by Davis King

Modified the svm_nu_trainer so that it uses its kernel matrix cache more efficiently.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402887
parent 3cb6611d
...@@ -203,6 +203,15 @@ namespace dlib ...@@ -203,6 +203,15 @@ namespace dlib
return (lookup(r) != -1); return (lookup(r) != -1);
} }
const scalar_type* col(long i) const
{
if (is_cached(i) == false)
add_col_to_cache(i);
return &cache(lookup(i),0);
}
const scalar_type* diag() const { return &diag_cache(0); }
inline scalar_type operator () ( inline scalar_type operator () (
long r, long r,
long c long c
...@@ -223,22 +232,29 @@ namespace dlib ...@@ -223,22 +232,29 @@ namespace dlib
} }
else else
{ {
// if the lookup table is pointing to cache(next,*) then clear lookup(next) add_col_to_cache(c);
if (rlookup(next) != -1) return cache(lookup(c),r);
lookup(rlookup(next)) = -1; }
}
// make the lookup table os that it says c is now cached at the spot indicated by next private:
lookup(c) = next; void add_col_to_cache(
rlookup(next) = c; long c
) const
{
// if the lookup table is pointing to cache(next,*) then clear lookup(next)
if (rlookup(next) != -1)
lookup(rlookup(next)) = -1;
// compute this column in the kernel matrix and store it in the cache // make the lookup table so that it says c is now cached at the spot indicated by next
for (long i = 0; i < cache.nc(); ++i) lookup(c) = next;
cache(next,i) = y(c)*y(i)*kernel_function(x(c),x(i)); rlookup(next) = c;
scalar_type val = cache(next,r); // compute this column in the kernel matrix and store it in the cache
next = (next + 1)%cache.nr(); for (long i = 0; i < cache.nc(); ++i)
return val; cache(next,i) = y(c)*y(i)*kernel_function(x(c),x(i));
}
next = (next + 1)%cache.nr();
} }
}; };
...@@ -1050,14 +1066,17 @@ namespace dlib ...@@ -1050,14 +1066,17 @@ namespace dlib
set_initial_alpha(y, nu, alpha); set_initial_alpha(y, nu, alpha);
set_all_elements(df, 0);
// initialize df. Compute df = Q*alpha // initialize df. Compute df = Q*alpha
for (long r = 0; r < df.nr(); ++r) for (long r = 0; r < df.nr(); ++r)
{ {
df(r) = 0; if (alpha(r) != 0)
for (long c = 0; c < alpha.nr(); ++c)
{ {
if (alpha(c) != 0) const scalar_type* Q_r = Q.col(r);
df(r) += Q(c,r)*alpha(c); for (long c = 0; c < alpha.nr(); ++c)
{
df(c) += alpha(r)*Q_r[c];
}
} }
} }
...@@ -1073,8 +1092,12 @@ namespace dlib ...@@ -1073,8 +1092,12 @@ namespace dlib
// update the df vector now that we have modified alpha(i) and alpha(j) // update the df vector now that we have modified alpha(i) and alpha(j)
scalar_type delta_alpha_i = alpha(i) - old_alpha_i; scalar_type delta_alpha_i = alpha(i) - old_alpha_i;
scalar_type delta_alpha_j = alpha(j) - old_alpha_j; scalar_type delta_alpha_j = alpha(j) - old_alpha_j;
const scalar_type* Q_i = Q.col(i);
const scalar_type* Q_j = Q.col(j);
for(long k = 0; k < df.nr(); ++k) for(long k = 0; k < df.nr(); ++k)
df(k) += Q(k,i)*delta_alpha_i + Q(k,j)*delta_alpha_j; df(k) += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
} }
...@@ -1253,6 +1276,18 @@ namespace dlib ...@@ -1253,6 +1276,18 @@ namespace dlib
scalar_type bp = -numeric_limits<scalar_type>::infinity(); scalar_type bp = -numeric_limits<scalar_type>::infinity();
scalar_type bn = -numeric_limits<scalar_type>::infinity(); scalar_type bn = -numeric_limits<scalar_type>::infinity();
// As a speed hack, pull out pointers to the columns of the
// kernel matrix we will be using below rather than accessing
// them through the Q(r,c) syntax.
const scalar_type* Q_ip = 0;
const scalar_type* Q_in = 0;
const scalar_type* Q_diag = Q.diag();
if (ip != -1)
Q_ip = Q.col(ip);
if (in != -1)
Q_in = Q.col(in);
// now we need to find the minimum jp and jn indices // now we need to find the minimum jp and jn indices
for (long j = 0; j < alpha.nr(); ++j) for (long j = 0; j < alpha.nr(); ++j)
{ {
...@@ -1264,10 +1299,10 @@ namespace dlib ...@@ -1264,10 +1299,10 @@ namespace dlib
if (-df(j) < Mp) if (-df(j) < Mp)
Mp = -df(j); Mp = -df(j);
if (b > 0 && (Q.is_cached(j) || b > bp || jp == -1 )) if (b > 0)
{ {
bp = b; bp = b;
scalar_type a = Q(ip,ip) + Q(j,j) - 2*Q(j,ip); scalar_type a = Q_ip[ip] + Q_diag[j] - 2*Q_ip[j];
if (a <= 0) if (a <= 0)
a = tau; a = tau;
scalar_type temp = -b*b/a; scalar_type temp = -b*b/a;
...@@ -1287,10 +1322,10 @@ namespace dlib ...@@ -1287,10 +1322,10 @@ namespace dlib
if (df(j) < Mn) if (df(j) < Mn)
Mn = df(j); Mn = df(j);
if (b > 0 && (Q.is_cached(j) || b > bn || jn == -1 )) if (b > 0)
{ {
bn = b; bn = b;
scalar_type a = Q(in,in) + Q(j,j) - 2*Q(j,in); scalar_type a = Q_in[in] + Q_diag[j] - 2*Q_in[j];
if (a <= 0) if (a <= 0)
a = tau; a = tau;
scalar_type temp = -b*b/a; scalar_type temp = -b*b/a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment