Commit d6874d83 authored by Davis King's avatar Davis King

Removed the last bit of code with any heritage from numerical recipes in C.

This was in some of the svd routines.  However, we already had a svd routine
that used a separate svd code that is better than the NRIC derived version.  So
that's what we use everywhere now.
parent 6f80e810
......@@ -28,398 +28,30 @@ namespace dlib
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
namespace nric
enum svd_u_mode
{
// This namespace contains stuff adapted from the algorithms
// described in the book Numerical Recipes in C
template <typename T>
inline T pythag(const T& a, const T& b)
{
T absa,absb;
absa=std::abs(a);
absb=std::abs(b);
if (absa > absb)
{
T val = absb/absa;
val *= val;
return absa*std::sqrt(1.0+val);
}
else
{
if (absb == 0.0)
{
return 0.0;
}
else
{
T val = absa/absb;
val *= val;
return absb*std::sqrt(1.0+val);
}
}
}
template <typename T>
inline T sign(const T& a, const T& b)
{
if (b < 0)
{
return -std::abs(a);
}
else
{
return std::abs(a);
}
}
template <
typename T,
long M, long N,
long wN, long wX,
long vN,
long rN, long rX,
typename MM1,
typename MM2,
typename MM3,
typename MM4,
typename L1,
typename L2,
typename L3,
typename L4
>
bool svdcmp(
matrix<T,M,N,MM1,L1>& a,
matrix<T,wN,wX,MM2,L2>& w,
matrix<T,vN,vN,MM3,L3>& v,
matrix<T,rN,rX,MM4,L4>& rv1
)
/*! ( this function is derived from the one in numerical recipes in C chapter 2.6)
requires
- w.nr() == a.nc()
- w.nc() == 1
- v.nr() == a.nc()
- v.nc() == a.nc()
- rv1.nr() == a.nc()
- rv1.nc() == 1
ensures
- computes the singular value decomposition of a
- let W be the matrix such that diag(W) == #w then:
- a == #a*W*trans(#v)
- trans(#a)*#a == identity matrix
- trans(#v)*#v == identity matrix
- #rv1 == some undefined value
- returns true for success and false for failure
!*/
{
DLIB_ASSERT(
w.nr() == a.nc() &&
w.nc() == 1 &&
v.nr() == a.nc() &&
v.nc() == a.nc() &&
rv1.nr() == a.nc() &&
rv1.nc() == 1, "");
COMPILE_TIME_ASSERT(wX == 0 || wX == 1);
COMPILE_TIME_ASSERT(rX == 0 || rX == 1);
const T one = 1.0;
const long max_iter = 300;
const long n = a.nc();
const long m = a.nr();
const T eps = std::numeric_limits<T>::epsilon();
long nm = 0, l = 0;
bool flag;
T anorm,c,f,g,h,s,scale,x,y,z;
g = 0.0;
scale = 0.0;
anorm = 0.0;
for (long i = 0; i < n; ++i)
{
l = i+1;
rv1(i) = scale*g;
g = s = scale = 0.0;
if (i < m)
{
for (long k = i; k < m; ++k)
scale += std::abs(a(k,i));
if (scale)
{
for (long k = i; k < m; ++k)
{
a(k,i) /= scale;
s += a(k,i)*a(k,i);
}
f = a(i,i);
g = -sign(std::sqrt(s),f);
h = f*g - s;
a(i,i) = f - g;
for (long j = l; j < n; ++j)
{
s = 0.0;
for (long k = i; k < m; ++k)
s += a(k,i)*a(k,j);
f = s/h;
for (long k = i; k < m; ++k)
a(k,j) += f*a(k,i);
}
for (long k = i; k < m; ++k)
a(k,i) *= scale;
}
}
w(i) = scale *g;
g=s=scale=0.0;
if (i < m && i < n-1)
{
for (long k = l; k < n; ++k)
scale += std::abs(a(i,k));
if (scale)
{
for (long k = l; k < n; ++k)
{
a(i,k) /= scale;
s += a(i,k)*a(i,k);
}
f = a(i,l);
g = -sign(std::sqrt(s),f);
h = f*g - s;
a(i,l) = f - g;
for (long k = l; k < n; ++k)
rv1(k) = a(i,k)/h;
for (long j = l; j < m; ++j)
{
s = 0.0;
for (long k = l; k < n; ++k)
s += a(j,k)*a(i,k);
for (long k = l; k < n; ++k)
a(j,k) += s*rv1(k);
}
for (long k = l; k < n; ++k)
a(i,k) *= scale;
}
}
anorm = std::max(anorm,(std::abs(w(i))+std::abs(rv1(i))));
}
for (long i = n-1; i >= 0; --i)
{
if (i < n-1)
{
if (g != 0)
{
for (long j = l; j < n ; ++j)
v(j,i) = (a(i,j)/a(i,l))/g;
for (long j = l; j < n; ++j)
{
s = 0.0;
for (long k = l; k < n; ++k)
s += a(i,k)*v(k,j);
for (long k = l; k < n; ++k)
v(k,j) += s*v(k,i);
}
}
for (long j = l; j < n; ++j)
v(i,j) = v(j,i) = 0.0;
}
v(i,i) = 1.0;
g = rv1(i);
l = i;
}
for (long i = std::min(m,n)-1; i >= 0; --i)
{
l = i + 1;
g = w(i);
for (long j = l; j < n; ++j)
a(i,j) = 0.0;
if (g != 0)
{
g = 1.0/g;
for (long j = l; j < n; ++j)
{
s = 0.0;
for (long k = l; k < m; ++k)
s += a(k,i)*a(k,j);
f=(s/a(i,i))*g;
for (long k = i; k < m; ++k)
a(k,j) += f*a(k,i);
}
for (long j = i; j < m; ++j)
a(j,i) *= g;
}
else
{
for (long j = i; j < m; ++j)
a(j,i) = 0.0;
}
++a(i,i);
}
for (long k = n-1; k >= 0; --k)
{
for (long its = 1; its <= max_iter; ++its)
{
flag = true;
for (l = k; l >= 1; --l)
{
nm = l - 1;
if (std::abs(rv1(l)) <= eps*anorm)
{
flag = false;
break;
}
if (std::abs(w(nm)) <= eps*anorm)
{
break;
}
}
if (flag)
{
c = 0.0;
s = 1.0;
for (long i = l; i <= k; ++i)
{
f = s*rv1(i);
rv1(i) = c*rv1(i);
if (std::abs(f) <= eps*anorm)
break;
g = w(i);
h = pythag(f,g);
w(i) = h;
h = 1.0/h;
c = g*h;
s = -f*h;
for (long j = 0; j < m; ++j)
{
y = a(j,nm);
z = a(j,i);
a(j,nm) = y*c + z*s;
a(j,i) = z*c - y*s;
}
}
}
z = w(k);
if (l == k)
{
if (z < 0.0)
{
w(k) = -z;
for (long j = 0; j < n; ++j)
v(j,k) = -v(j,k);
}
break;
}
if (its == max_iter)
return false;
x = w(l);
nm = k - 1;
y = w(nm);
g = rv1(nm);
h = rv1(k);
f = ((y-z)*(y+z) + (g-h)*(g+h))/(2.0*h*y);
g = pythag(f,one);
f = ((x-z)*(x+z) + h*((y/(f+sign(g,f)))-h))/x;
c = s = 1.0;
for (long j = l; j <= nm; ++j)
{
long i = j + 1;
g = rv1(i);
y = w(i);
h = s*g;
g = c*g;
z = pythag(f,h);
rv1(j) = z;
c = f/z;
s = h/z;
f = x*c + g*s;
g = g*c - x*s;
h = y*s;
y *= c;
for (long jj = 0; jj < n; ++jj)
{
x = v(jj,j);
z = v(jj,i);
v(jj,j) = x*c + z*s;
v(jj,i) = z*c - x*s;
}
z = pythag(f,h);
w(j) = z;
if (z != 0)
{
z = 1.0/z;
c = f*z;
s = h*z;
}
f = c*g + s*y;
x = c*y - s*g;
for (long jj = 0; jj < m; ++jj)
{
y = a(jj,j);
z = a(jj,i);
a(jj,j) = y*c + z*s;
a(jj,i) = z*c - y*s;
}
}
rv1(l) = 0.0;
rv1(k) = f;
w(k) = x;
}
}
return true;
}
// ------------------------------------------------------------------------------------
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
SVD_NO_U,
SVD_SKINNY_U,
SVD_FULL_U
};
template <
typename EXP,
long qN, long qX,
long uM,
long vN,
long uM, long uN,
long vM, long vN,
typename MM1,
typename MM2,
typename MM3,
typename L1
>
long svd2 (
bool withu,
long svd4 (
svd_u_mode u_mode,
bool withv,
const matrix_exp<EXP>& a,
matrix<typename EXP::type,uM,uM,MM1,L1>& u,
matrix<typename EXP::type,uM,uN,MM1,L1>& u,
matrix<typename EXP::type,qN,qX,MM2,L1>& q,
matrix<typename EXP::type,vN,vN,MM3,L1>& v
matrix<typename EXP::type,vM,vN,MM3,L1>& v
)
{
/*
......@@ -444,20 +76,13 @@ namespace dlib
and v an n x n orthogonal matrix. eps and tol are tolerance constants.
Suitable values are eps=1e-16 and tol=(1e-300)/eps if T == double.
If withu == false then u won't be computed and similarly if withv == false
then v won't be computed.
If u_mode == SVD_NO_U then u won't be computed and similarly if withv == false
then v won't be computed. If u_mode == SVD_SKINNY_U then u will be m x n instead of m x m.
*/
const long NR = matrix_exp<EXP>::NR;
const long NC = matrix_exp<EXP>::NC;
// make sure the output matrices have valid dimensions if they are statically dimensioned
COMPILE_TIME_ASSERT(qX == 0 || qX == 1);
COMPILE_TIME_ASSERT(NR == 0 || uM == 0 || NR == uM);
COMPILE_TIME_ASSERT(NC == 0 || vN == 0 || NC == vN);
DLIB_ASSERT(a.nr() >= a.nc(),
"\tconst matrix_exp svd2()"
"\tconst matrix_exp svd4()"
<< "\n\tYou have given an invalidly sized matrix"
<< "\n\ta.nr(): " << a.nr()
<< "\n\ta.nc(): " << a.nc()
......@@ -467,30 +92,33 @@ namespace dlib
typedef typename EXP::type T;
#ifdef DLIB_USE_LAPACK
matrix<typename EXP::type,0,0,MM1,L1> temp(a);
matrix<typename EXP::type,0,0,MM1,L1> temp(a), vtemp;
char jobu = 'A';
char jobvt = 'A';
if (withu == false)
if (u_mode == SVD_NO_U)
jobu = 'N';
else if (u_mode == SVD_SKINNY_U)
jobu = 'S';
if (withv == false)
jobvt = 'N';
int info;
if (withu == withv)
if (jobu == jobvt)
{
info = lapack::gesdd(jobu, temp, q, u, v);
info = lapack::gesdd(jobu, temp, q, u, vtemp);
}
else
{
info = lapack::gesvd(jobu, jobvt, temp, q, u, v);
info = lapack::gesvd(jobu, jobvt, temp, q, u, vtemp);
}
// pad q with zeros if it isn't the length we want
if (q.nr() < a.nc())
q = join_cols(q, zeros_matrix<T>(a.nc()-q.nr(),1));
v = trans(v);
if (withv)
v = trans(vtemp);
return info;
#else
......@@ -507,7 +135,10 @@ namespace dlib
matrix<T,qN,1,MM2> e(n,1);
q.set_size(n,1);
u.set_size(m,m);
if (u_mode == SVD_FULL_U)
u.set_size(m,m);
else
u.set_size(m,n);
retval = 0;
if (withv)
......@@ -625,33 +256,34 @@ namespace dlib
} /* end withv, parens added for clarity */
/* accumulation of left-hand transformations */
if (withu)
if (u_mode != SVD_NO_U)
{
for (i=n; i<m; i++)
for (i=n; i<u.nr(); i++)
{
for (j=n;j<m;j++)
for (j=n;j<u.nc();j++)
u(i,j) = 0.0;
u(i,i) = 1.0;
if (i < u.nc())
u(i,i) = 1.0;
}
}
if (withu)
if (u_mode != SVD_NO_U)
{
for (i=n-1; i>=0; i--)
{
l = i + 1;
g = q(i);
for (j=l; j<m; j++) /* upper limit was 'n' */
for (j=l; j<u.nc(); j++)
u(i,j) = 0.0;
if (g != 0.0)
{
h = u(i,i) * g;
for (j=l; j<m; j++)
{ /* upper limit was 'n' */
for (j=l; j<u.nc(); j++)
{
s = 0.0;
for (k=l; k<m; k++)
......@@ -674,7 +306,7 @@ namespace dlib
u(i,i) += 1.0;
} /* end i*/
} /* end withu, parens added for clarity */
}
/* diagonalization of the bidiagonal form */
eps *= x;
......@@ -715,7 +347,7 @@ cancellation:
c = g / h;
s = -f / h;
if (withu)
if (u_mode != SVD_NO_U)
{
for (j=0; j<m; j++)
{
......@@ -724,7 +356,7 @@ cancellation:
u(j,l1) = y * c + z * s;
u(j,i) = -y * s + z * c;
} /* end j */
} /* end withu, parens added for clarity */
}
} /* end i */
test_f_convergence:
......@@ -777,11 +409,14 @@ test_f_convergence:
} /* end withv, parens added for clarity */
q(i-1) = z = sqrt(f * f + h * h);
c = f / z;
s = h / z;
if (z != 0)
{
c = f / z;
s = h / z;
}
f = c * g + s * y;
x = -s * g + c * y;
if (withu)
if (u_mode != SVD_NO_U)
{
for (j=0; j<m; j++)
{
......@@ -790,7 +425,7 @@ test_f_convergence:
u(j,i-1) = y * c + z * s;
u(j,i) = -y * s + z * c;
} /* end j */
} /* end withu, parens added for clarity */
}
} /* end i */
e(l) = 0.0;
......@@ -817,6 +452,48 @@ convergence:
#endif
}
// ----------------------------------------------------------------------------------------
template <
typename EXP,
long qN, long qX,
long uM,
long vN,
typename MM1,
typename MM2,
typename MM3,
typename L1
>
long svd2 (
bool withu,
bool withv,
const matrix_exp<EXP>& a,
matrix<typename EXP::type,uM,uM,MM1,L1>& u,
matrix<typename EXP::type,qN,qX,MM2,L1>& q,
matrix<typename EXP::type,vN,vN,MM3,L1>& v
)
{
const long NR = matrix_exp<EXP>::NR;
const long NC = matrix_exp<EXP>::NC;
// make sure the output matrices have valid dimensions if they are statically dimensioned
COMPILE_TIME_ASSERT(qX == 0 || qX == 1);
COMPILE_TIME_ASSERT(NR == 0 || uM == 0 || NR == uM);
COMPILE_TIME_ASSERT(NC == 0 || vN == 0 || NC == vN);
DLIB_ASSERT(a.nr() >= a.nc(),
"\tconst matrix_exp svd4()"
<< "\n\tYou have given an invalidly sized matrix"
<< "\n\ta.nr(): " << a.nr()
<< "\n\ta.nc(): " << a.nc()
);
if (withu)
return svd4(SVD_FULL_U, withv, a,u,q,v);
else
return svd4(SVD_NO_U, withv, a,u,q,v);
}
// ----------------------------------------------------------------------------------------
template <
......@@ -1090,7 +767,6 @@ convergence:
const matrix_exp<EXP>& m
)
{
using namespace nric;
// you can't invert a non-square matrix
COMPILE_TIME_ASSERT(matrix_exp<EXP>::NR == matrix_exp<EXP>::NC ||
matrix_exp<EXP>::NR == 0 ||
......@@ -1557,13 +1233,21 @@ convergence:
return;
}
#endif
v.set_size(m.nc(),m.nc());
u = m;
if (m.nr() >= m.nc())
{
svd4(SVD_SKINNY_U,true, m, u,w,v);
}
else
{
svd4(SVD_FULL_U,true, trans(m), v,w,u);
w.set_size(m.nc(),1);
matrix<T,matrix_exp<EXP>::NC,1,MM1> rv1(m.nc(),1);
nric::svdcmp(u,w,v,rv1);
// if u isn't the size we want then pad it (and v) with zeros
if (u.nc() < m.nc())
{
w = join_cols(w, zeros_matrix<T>(m.nc()-u.nc(),1));
u = join_rows(u, zeros_matrix<T>(u.nr(), m.nc()-u.nc()));
}
}
}
// ----------------------------------------------------------------------------------------
......@@ -1692,7 +1376,6 @@ convergence:
const matrix_exp<EXP>& m
)
{
using namespace nric;
COMPILE_TIME_ASSERT(matrix_exp<EXP>::NR == matrix_exp<EXP>::NC ||
matrix_exp<EXP>::NR == 0 ||
matrix_exp<EXP>::NC == 0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment