Commit 1921e504 authored by Davis King's avatar Davis King

Pushed all the work variables into the LAPACK binding functions.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403823
parent 16f452a6
...@@ -203,8 +203,8 @@ namespace dlib ...@@ -203,8 +203,8 @@ namespace dlib
template < template <
typename T, typename T,
long NR1, long NR2, long NR3, long NR4, long NR5, long NR1, long NR2, long NR3, long NR4,
long NC1, long NC2, long NC3, long NC4, long NC5, long NC1, long NC2, long NC3, long NC4,
typename MM, typename MM,
typename layout typename layout
> >
...@@ -213,10 +213,11 @@ namespace dlib ...@@ -213,10 +213,11 @@ namespace dlib
matrix<T,NR1,NC1,MM,column_major_layout>& a, matrix<T,NR1,NC1,MM,column_major_layout>& a,
matrix<T,NR2,NC2,MM,layout>& wr, matrix<T,NR2,NC2,MM,layout>& wr,
matrix<T,NR3,NC3,MM,layout>& wi, matrix<T,NR3,NC3,MM,layout>& wi,
matrix<T,NR4,NC4,MM,column_major_layout>& vs, matrix<T,NR4,NC4,MM,column_major_layout>& vs
matrix<T,NR5,NC5,MM,column_major_layout>& work
) )
{ {
matrix<T,0,1,MM,column_major_layout> work;
const long n = a.nr(); const long n = a.nr();
wr.set_size(n,1); wr.set_size(n,1);
......
...@@ -164,8 +164,8 @@ namespace dlib ...@@ -164,8 +164,8 @@ namespace dlib
template < template <
typename T, typename T,
long NR1, long NR2, long NR3, long NR4, long NR5, long NR6, long NR1, long NR2, long NR3, long NR4, long NR5,
long NC1, long NC2, long NC3, long NC4, long NC5, long NC6, long NC1, long NC2, long NC3, long NC4, long NC5,
typename MM, typename MM,
typename layout typename layout
> >
...@@ -176,10 +176,11 @@ namespace dlib ...@@ -176,10 +176,11 @@ namespace dlib
matrix<T,NR2,NC2,MM,layout>& wr, matrix<T,NR2,NC2,MM,layout>& wr,
matrix<T,NR3,NC3,MM,layout>& wi, matrix<T,NR3,NC3,MM,layout>& wi,
matrix<T,NR4,NC4,MM,column_major_layout>& vl, matrix<T,NR4,NC4,MM,column_major_layout>& vl,
matrix<T,NR5,NC5,MM,column_major_layout>& vr, matrix<T,NR5,NC5,MM,column_major_layout>& vr
matrix<T,NR6,NC6,MM,column_major_layout>& work
) )
{ {
matrix<T,0,1,MM,column_major_layout> work;
const long n = a.nr(); const long n = a.nr();
wr.set_size(n,1); wr.set_size(n,1);
......
...@@ -121,16 +121,17 @@ namespace dlib ...@@ -121,16 +121,17 @@ namespace dlib
template < template <
typename T, typename T,
long NR1, long NR2, long NR3, long NR1, long NR2,
long NC1, long NC2, long NC3, long NC1, long NC2,
typename MM typename MM
> >
int geqrf ( int geqrf (
matrix<T,NR1,NC1,MM,column_major_layout>& a, matrix<T,NR1,NC1,MM,column_major_layout>& a,
matrix<T,NR2,NC2,MM,column_major_layout>& tau, matrix<T,NR2,NC2,MM,column_major_layout>& tau
matrix<T,NR3,NC3,MM,column_major_layout>& work
) )
{ {
matrix<T,0,1,MM,column_major_layout> work;
tau.set_size(std::min(a.nr(), a.nc()), 1); tau.set_size(std::min(a.nr(), a.nc()), 1);
// figure out how big the workspace needs to be. // figure out how big the workspace needs to be.
......
...@@ -189,8 +189,8 @@ namespace dlib ...@@ -189,8 +189,8 @@ namespace dlib
template < template <
typename T, typename T,
long NR1, long NR2, long NR3, long NR4, long NR5, long NR6, long NR1, long NR2, long NR3, long NR4,
long NC1, long NC2, long NC3, long NC4, long NC5, long NC6, long NC1, long NC2, long NC3, long NC4,
typename MM typename MM
> >
int gesdd ( int gesdd (
...@@ -198,11 +198,12 @@ namespace dlib ...@@ -198,11 +198,12 @@ namespace dlib
matrix<T,NR1,NC1,MM,column_major_layout>& a, matrix<T,NR1,NC1,MM,column_major_layout>& a,
matrix<T,NR2,NC2,MM,column_major_layout>& s, matrix<T,NR2,NC2,MM,column_major_layout>& s,
matrix<T,NR3,NC3,MM,column_major_layout>& u, matrix<T,NR3,NC3,MM,column_major_layout>& u,
matrix<T,NR4,NC4,MM,column_major_layout>& vt, matrix<T,NR4,NC4,MM,column_major_layout>& vt
matrix<T,NR5,NC5,MM,column_major_layout>& work,
matrix<integer,NR6,NC6,MM,column_major_layout>& iwork
) )
{ {
matrix<T,0,1,MM,column_major_layout> work;
matrix<integer,0,1,MM,column_major_layout> iwork;
const long m = a.nr(); const long m = a.nr();
const long n = a.nc(); const long n = a.nc();
s.set_size(std::min(m,n), 1); s.set_size(std::min(m,n), 1);
...@@ -251,8 +252,8 @@ namespace dlib ...@@ -251,8 +252,8 @@ namespace dlib
template < template <
typename T, typename T,
long NR1, long NR2, long NR3, long NR4, long NR5, long NR6, long NR1, long NR2, long NR3, long NR4,
long NC1, long NC2, long NC3, long NC4, long NC5, long NC6, long NC1, long NC2, long NC3, long NC4,
typename MM typename MM
> >
int gesdd ( int gesdd (
...@@ -260,11 +261,12 @@ namespace dlib ...@@ -260,11 +261,12 @@ namespace dlib
matrix<T,NR1,NC1,MM,row_major_layout>& a, matrix<T,NR1,NC1,MM,row_major_layout>& a,
matrix<T,NR2,NC2,MM,row_major_layout>& s, matrix<T,NR2,NC2,MM,row_major_layout>& s,
matrix<T,NR3,NC3,MM,row_major_layout>& u_, matrix<T,NR3,NC3,MM,row_major_layout>& u_,
matrix<T,NR4,NC4,MM,row_major_layout>& vt_, matrix<T,NR4,NC4,MM,row_major_layout>& vt_
matrix<T,NR5,NC5,MM,row_major_layout>& work,
matrix<integer,NR6,NC6,MM,row_major_layout>& iwork
) )
{ {
matrix<T,0,1,MM,row_major_layout> work;
matrix<integer,0,1,MM,row_major_layout> iwork;
// Row major order matrices are transposed from LAPACK's point of view. // Row major order matrices are transposed from LAPACK's point of view.
matrix<T,NR3,NC3,MM,row_major_layout>& u = vt_; matrix<T,NR3,NC3,MM,row_major_layout>& u = vt_;
matrix<T,NR4,NC4,MM,row_major_layout>& vt = u_; matrix<T,NR4,NC4,MM,row_major_layout>& vt = u_;
......
...@@ -181,8 +181,8 @@ namespace dlib ...@@ -181,8 +181,8 @@ namespace dlib
template < template <
typename T, typename T,
long NR1, long NR2, long NR3, long NR4, long NR5, long NR1, long NR2, long NR3, long NR4,
long NC1, long NC2, long NC3, long NC4, long NC5, long NC1, long NC2, long NC3, long NC4,
typename MM typename MM
> >
int gesvd ( int gesvd (
...@@ -191,10 +191,11 @@ namespace dlib ...@@ -191,10 +191,11 @@ namespace dlib
matrix<T,NR1,NC1,MM,column_major_layout>& a, matrix<T,NR1,NC1,MM,column_major_layout>& a,
matrix<T,NR2,NC2,MM,column_major_layout>& s, matrix<T,NR2,NC2,MM,column_major_layout>& s,
matrix<T,NR3,NC3,MM,column_major_layout>& u, matrix<T,NR3,NC3,MM,column_major_layout>& u,
matrix<T,NR4,NC4,MM,column_major_layout>& vt, matrix<T,NR4,NC4,MM,column_major_layout>& vt
matrix<T,NR5,NC5,MM,column_major_layout>& work
) )
{ {
matrix<T,0,1,MM,column_major_layout> work;
const long m = a.nr(); const long m = a.nr();
const long n = a.nc(); const long n = a.nc();
s.set_size(std::min(m,n), 1); s.set_size(std::min(m,n), 1);
...@@ -237,8 +238,8 @@ namespace dlib ...@@ -237,8 +238,8 @@ namespace dlib
template < template <
typename T, typename T,
long NR1, long NR2, long NR3, long NR4, long NR5, long NR1, long NR2, long NR3, long NR4,
long NC1, long NC2, long NC3, long NC4, long NC5, long NC1, long NC2, long NC3, long NC4,
typename MM typename MM
> >
int gesvd ( int gesvd (
...@@ -247,10 +248,11 @@ namespace dlib ...@@ -247,10 +248,11 @@ namespace dlib
matrix<T,NR1,NC1,MM,row_major_layout>& a, matrix<T,NR1,NC1,MM,row_major_layout>& a,
matrix<T,NR2,NC2,MM,row_major_layout>& s, matrix<T,NR2,NC2,MM,row_major_layout>& s,
matrix<T,NR3,NC3,MM,row_major_layout>& u_, matrix<T,NR3,NC3,MM,row_major_layout>& u_,
matrix<T,NR4,NC4,MM,row_major_layout>& vt_, matrix<T,NR4,NC4,MM,row_major_layout>& vt_
matrix<T,NR5,NC5,MM,row_major_layout>& work
) )
{ {
matrix<T,0,1,MM,row_major_layout> work;
// Row major order matrices are transposed from LAPACK's point of view. // Row major order matrices are transposed from LAPACK's point of view.
matrix<T,NR3,NC3,MM,row_major_layout>& u = vt_; matrix<T,NR3,NC3,MM,row_major_layout>& u = vt_;
matrix<T,NR4,NC4,MM,row_major_layout>& vt = u_; matrix<T,NR4,NC4,MM,row_major_layout>& vt = u_;
......
...@@ -117,18 +117,19 @@ namespace dlib ...@@ -117,18 +117,19 @@ namespace dlib
template < template <
typename T, typename T,
long NR1, long NR2, long NR3, long NR1, long NR2,
long NC1, long NC2, long NC3, long NC1, long NC2,
typename MM typename MM
> >
int syev ( int syev (
const char jobz, const char jobz,
const char uplo, const char uplo,
matrix<T,NR1,NC1,MM,column_major_layout>& a, matrix<T,NR1,NC1,MM,column_major_layout>& a,
matrix<T,NR2,NC2,MM,column_major_layout>& w, matrix<T,NR2,NC2,MM,column_major_layout>& w
matrix<T,NR3,NC3,MM,column_major_layout>& work
) )
{ {
matrix<T,0,1,MM,column_major_layout> work;
const long n = a.nr(); const long n = a.nr();
w.set_size(n,1); w.set_size(n,1);
...@@ -156,18 +157,19 @@ namespace dlib ...@@ -156,18 +157,19 @@ namespace dlib
template < template <
typename T, typename T,
long NR1, long NR2, long NR3, long NR1, long NR2,
long NC1, long NC2, long NC3, long NC1, long NC2,
typename MM typename MM
> >
int syev ( int syev (
char jobz, char jobz,
char uplo, char uplo,
matrix<T,NR1,NC1,MM,row_major_layout>& a, matrix<T,NR1,NC1,MM,row_major_layout>& a,
matrix<T,NR2,NC2,MM,row_major_layout>& w, matrix<T,NR2,NC2,MM,row_major_layout>& w
matrix<T,NR3,NC3,MM,row_major_layout>& work
) )
{ {
matrix<T,0,1,MM,row_major_layout> work;
if (uplo == 'L') if (uplo == 'L')
uplo = 'U'; uplo = 'U';
else else
......
...@@ -291,8 +291,8 @@ namespace dlib ...@@ -291,8 +291,8 @@ namespace dlib
template < template <
typename T, typename T,
long NR1, long NR2, long NR3, long NR4, long NR5, long NR6, long NR1, long NR2, long NR3, long NR4,
long NC1, long NC2, long NC3, long NC4, long NC5, long NC6, long NC1, long NC2, long NC3, long NC4,
typename MM typename MM
> >
int syevr ( int syevr (
...@@ -308,11 +308,12 @@ namespace dlib ...@@ -308,11 +308,12 @@ namespace dlib
integer& num_eigenvalues_found, integer& num_eigenvalues_found,
matrix<T,NR2,NC2,MM,column_major_layout>& w, matrix<T,NR2,NC2,MM,column_major_layout>& w,
matrix<T,NR3,NC3,MM,column_major_layout>& z, matrix<T,NR3,NC3,MM,column_major_layout>& z,
matrix<integer,NR4,NC4,MM,column_major_layout>& isuppz, matrix<integer,NR4,NC4,MM,column_major_layout>& isuppz
matrix<T,NR5,NC5,MM,column_major_layout>& work,
matrix<integer,NR6,NC6,MM,column_major_layout>& iwork
) )
{ {
matrix<T,0,1,MM,column_major_layout> work;
matrix<integer,0,1,MM,column_major_layout> iwork;
const long n = a.nr(); const long n = a.nr();
w.set_size(n,1); w.set_size(n,1);
...@@ -358,8 +359,8 @@ namespace dlib ...@@ -358,8 +359,8 @@ namespace dlib
template < template <
typename T, typename T,
long NR1, long NR2, long NR3, long NR4, long NR5, long NR6, long NR1, long NR2, long NR3, long NR4,
long NC1, long NC2, long NC3, long NC4, long NC5, long NC6, long NC1, long NC2, long NC3, long NC4,
typename MM typename MM
> >
int syevr ( int syevr (
...@@ -375,11 +376,12 @@ namespace dlib ...@@ -375,11 +376,12 @@ namespace dlib
integer& num_eigenvalues_found, integer& num_eigenvalues_found,
matrix<T,NR2,NC2,MM,row_major_layout>& w, matrix<T,NR2,NC2,MM,row_major_layout>& w,
matrix<T,NR3,NC3,MM,row_major_layout>& z, matrix<T,NR3,NC3,MM,row_major_layout>& z,
matrix<integer,NR4,NC4,MM,row_major_layout>& isuppz, matrix<integer,NR4,NC4,MM,row_major_layout>& isuppz
matrix<T,NR5,NC5,MM,row_major_layout>& work,
matrix<integer,NR6,NC6,MM,row_major_layout>& iwork
) )
{ {
matrix<T,0,1,MM,row_major_layout> work;
matrix<integer,0,1,MM,row_major_layout> iwork;
if (uplo == 'L') if (uplo == 'L')
uplo = 'U'; uplo = 'U';
else else
......
...@@ -178,11 +178,10 @@ namespace dlib ...@@ -178,11 +178,10 @@ namespace dlib
V = A; V = A;
#ifdef DLIB_USE_LAPACK #ifdef DLIB_USE_LAPACK
matrix<type,0,1,mem_manager_type, layout_type> work;
e = 0; e = 0;
// I would use syevr but the last time I checked there was a bug in the // I would use syevr but the last time I checked there was a bug in the
// Intel MKL's implementation of syevr. // Intel MKL's implementation of syevr.
lapack::syev('V', 'L', V, d, work); lapack::syev('V', 'L', V, d);
#else #else
// Tridiagonalize. // Tridiagonalize.
tred2(); tred2();
...@@ -196,9 +195,9 @@ namespace dlib ...@@ -196,9 +195,9 @@ namespace dlib
{ {
#ifdef DLIB_USE_LAPACK #ifdef DLIB_USE_LAPACK
matrix<type,0,0,mem_manager_type, column_major_layout> temp, vl, vr, work; matrix<type,0,0,mem_manager_type, column_major_layout> temp, vl, vr;
temp = A; temp = A;
lapack::geev('N', 'V', temp, d, e, vl, vr, work); lapack::geev('N', 'V', temp, d, e, vl, vr);
V = vr; V = vr;
#else #else
H = A; H = A;
...@@ -246,11 +245,10 @@ namespace dlib ...@@ -246,11 +245,10 @@ namespace dlib
V = A; V = A;
#ifdef DLIB_USE_LAPACK #ifdef DLIB_USE_LAPACK
matrix<type,0,1,mem_manager_type, layout_type> work;
e = 0; e = 0;
// I would use syevr but the last time I checked there was a bug in the // I would use syevr but the last time I checked there was a bug in the
// Intel MKL's implementation of syevr. // Intel MKL's implementation of syevr.
lapack::syev('V', 'L', V, d, work); lapack::syev('V', 'L', V, d);
#else #else
// Tridiagonalize. // Tridiagonalize.
tred2(); tred2();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment