Commit ecd51c42 authored by Davis King's avatar Davis King

Setup the SVD routines to use LAPACK when available. I also changed the svd functions

so that you can't supply output matrices which use both column and row major layouts.
Now all the output matrices need to use the same memory layout.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403826
parent 76981f1e
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#ifdef DLIB_USE_LAPACK #ifdef DLIB_USE_LAPACK
#include "lapack/potrf.h" #include "lapack/potrf.h"
#include "lapack/gesdd.h"
#include "lapack/gesvd.h"
#endif #endif
namespace dlib namespace dlib
...@@ -567,17 +569,15 @@ namespace dlib ...@@ -567,17 +569,15 @@ namespace dlib
typename MM1, typename MM1,
typename MM2, typename MM2,
typename MM3, typename MM3,
typename L1, typename L1
typename L2,
typename L3
> >
long svd2 ( long svd2 (
bool withu, bool withu,
bool withv, bool withv,
const matrix_exp<EXP>& a, const matrix_exp<EXP>& a,
matrix<typename EXP::type,uM,uM,MM1,L1>& u, matrix<typename EXP::type,uM,uM,MM1,L1>& u,
matrix<typename EXP::type,qN,qX,MM2,L2>& q, matrix<typename EXP::type,qN,qX,MM2,L1>& q,
matrix<typename EXP::type,vN,vN,MM3,L3>& v matrix<typename EXP::type,vN,vN,MM3,L1>& v
) )
{ {
/* /*
...@@ -624,6 +624,34 @@ namespace dlib ...@@ -624,6 +624,34 @@ namespace dlib
typedef typename EXP::type T; typedef typename EXP::type T;
#ifdef DLIB_USE_LAPACK
matrix<typename EXP::type,0,0,MM1,L1> temp(a);
char jobu = 'A';
char jobvt = 'A';
if (withu == false)
jobu = 'N';
if (withv == false)
jobvt = 'N';
int info;
if (withu == withv)
{
info = lapack::gesdd(jobu, temp, q, u, v);
}
else
{
info = lapack::gesvd(jobu, jobvt, temp, q, u, v);
}
// pad q with zeros if it isn't the length we want
if (q.nr() < a.nc())
q = join_cols(q, zeros_matrix<T>(a.nc()-q.nr(),1));
v = trans(v);
return info;
#else
using std::abs; using std::abs;
using std::sqrt; using std::sqrt;
...@@ -944,6 +972,7 @@ convergence: ...@@ -944,6 +972,7 @@ convergence:
} /* end k */ } /* end k */
return retval; return retval;
#endif
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -1354,15 +1383,13 @@ convergence: ...@@ -1354,15 +1383,13 @@ convergence:
typename MM1, typename MM1,
typename MM2, typename MM2,
typename MM3, typename MM3,
typename L1, typename L1
typename L2,
typename L3
> >
inline void svd3 ( inline void svd3 (
const matrix_exp<EXP>& m, const matrix_exp<EXP>& m,
matrix<typename matrix_exp<EXP>::type, uNR, uNC,MM1,L1>& u, matrix<typename matrix_exp<EXP>::type, uNR, uNC,MM1,L1>& u,
matrix<typename matrix_exp<EXP>::type, wN, wX,MM2,L2>& w, matrix<typename matrix_exp<EXP>::type, wN, wX,MM2,L1>& w,
matrix<typename matrix_exp<EXP>::type, vN, vN,MM3,L3>& v matrix<typename matrix_exp<EXP>::type, vN, vN,MM3,L1>& v
) )
{ {
typedef typename matrix_exp<EXP>::type T; typedef typename matrix_exp<EXP>::type T;
...@@ -1376,14 +1403,27 @@ convergence: ...@@ -1376,14 +1403,27 @@ convergence:
COMPILE_TIME_ASSERT(NC == 0 || vN == 0 || NC == vN); COMPILE_TIME_ASSERT(NC == 0 || vN == 0 || NC == vN);
COMPILE_TIME_ASSERT(wX == 0 || wX == 1); COMPILE_TIME_ASSERT(wX == 0 || wX == 1);
typedef typename matrix_exp<EXP>::type T;
#ifdef DLIB_USE_LAPACK
matrix<typename matrix_exp<EXP>::type, uNR, uNC,MM1,L1> temp(m);
lapack::gesvd('S','A', temp, w, u, v);
v = trans(v);
// if u isn't the size we want then pad it (and v) with zeros
if (u.nc() < m.nc())
{
w = join_cols(w, zeros_matrix<T>(m.nc()-u.nc(),1));
u = join_rows(u, zeros_matrix<T>(u.nr(), m.nc()-u.nc()));
}
#else
v.set_size(m.nc(),m.nc()); v.set_size(m.nc(),m.nc());
typedef typename matrix_exp<EXP>::type T;
u = m; u = m;
w.set_size(m.nc(),1); w.set_size(m.nc(),1);
matrix<T,matrix_exp<EXP>::NC,1,MM1> rv1(m.nc(),1); matrix<T,matrix_exp<EXP>::NC,1,MM1> rv1(m.nc(),1);
nric::svdcmp(u,w,v,rv1); nric::svdcmp(u,w,v,rv1);
#endif
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -1397,15 +1437,13 @@ convergence: ...@@ -1397,15 +1437,13 @@ convergence:
typename MM1, typename MM1,
typename MM2, typename MM2,
typename MM3, typename MM3,
typename L1, typename L1
typename L2,
typename L3
> >
inline void svd ( inline void svd (
const matrix_exp<EXP>& m, const matrix_exp<EXP>& m,
matrix<typename matrix_exp<EXP>::type, uNR, uNC,MM1,L1>& u, matrix<typename matrix_exp<EXP>::type, uNR, uNC,MM1,L1>& u,
matrix<typename matrix_exp<EXP>::type, wN, wN,MM2,L2>& w, matrix<typename matrix_exp<EXP>::type, wN, wN,MM2,L1>& w,
matrix<typename matrix_exp<EXP>::type, vN, vN,MM3,L3>& v matrix<typename matrix_exp<EXP>::type, vN, vN,MM3,L1>& v
) )
{ {
typedef typename matrix_exp<EXP>::type T; typedef typename matrix_exp<EXP>::type T;
......
...@@ -62,6 +62,8 @@ namespace dlib ...@@ -62,6 +62,8 @@ namespace dlib
- #w.nc() == m.nc() - #w.nc() == m.nc()
- #v.nr() == m.nc() - #v.nr() == m.nc()
- #v.nc() == m.nc() - #v.nc() == m.nc()
- if DLIB_USE_LAPACK is #defined then the xGESVD routine
from LAPACK is used to compute the SVD.
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -98,6 +100,11 @@ namespace dlib ...@@ -98,6 +100,11 @@ namespace dlib
output state is undefined. output state is undefined.
- returns an error code of 0, if no errors and 'k' if we fail to - returns an error code of 0, if no errors and 'k' if we fail to
converge at the 'kth' singular value. converge at the 'kth' singular value.
- if (DLIB_USE_LAPACK is #defined) then
- if (withu == withv) then
- the xGESDD routine from LAPACK is used to compute the SVD.
- else
- the xGESVD routine from LAPACK is used to compute the SVD.
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -122,6 +129,8 @@ namespace dlib ...@@ -122,6 +129,8 @@ namespace dlib
- #w.nc() == 1 - #w.nc() == 1
- #v.nr() == m.nc() - #v.nr() == m.nc()
- #v.nc() == m.nc() - #v.nc() == m.nc()
- if DLIB_USE_LAPACK is #defined then the xGESVD routine
from LAPACK is used to compute the SVD.
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment