Commit 438ec254 authored by Davis King's avatar Davis King

Set up the LU decomposition code to use LAPACK when available. I also removed the older

version from numerical recipes and made everything depend on the lu_decomposition object
instead.  Finally, I added in a triangular solver that uses BLAS when available and made
the lu_decomposition object use it.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403833
parent 9de0a49c
#ifndef DLIB_CBLAS_CONSTAnTS_H__
#define DLIB_CBLAS_CONSTAnTS_H__
namespace dlib
{
namespace blas_bindings
{
// Constants passed to the CBLAS-style routines declared in the BLAS bindings
// (e.g. cblas_sgemm, cblas_strsm) and to dlib's local fallback implementations.

// Memory layout of the matrix arguments: row major or column major.
enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
// Which form of a matrix argument is used: as is, transposed, or conjugate transposed.
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
// Whether a triangular matrix argument is upper or lower triangular.
enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
// Whether a triangular matrix argument is assumed to have a unit diagonal.
enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
// Whether the triangular matrix appears on the left or the right of the unknown.
enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
}
}
#endif // DLIB_CBLAS_CONSTAnTS_H__
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_LAPACk_GETRF_H__ #ifndef DLIB_LAPACk_GETRF_H__
#define DLIB_LAPACk_GETRF_H__ #define DLIB_LAPACk_GETRF_H__
...@@ -106,28 +108,16 @@ namespace dlib ...@@ -106,28 +108,16 @@ namespace dlib
> >
int getrf ( int getrf (
matrix<T,NR1,NC1,MM,column_major_layout>& a, matrix<T,NR1,NC1,MM,column_major_layout>& a,
matrix<long,NR2,NC2,MM,layout>& ipiv matrix<integer,NR2,NC2,MM,layout>& ipiv
) )
{ {
const long m = a.nr(); const long m = a.nr();
const long n = a.nc(); const long n = a.nc();
matrix<integer,NR2,NC2,MM,column_major_layout> ipiv_temp(std::min(m,n), 1); ipiv.set_size(std::min(m,n), 1);
// compute the actual decomposition // compute the actual decomposition
int info = binding::getrf(m, n, &a(0,0), a.nr(), &ipiv_temp(0,0)); return binding::getrf(m, n, &a(0,0), a.nr(), &ipiv(0,0));
// Turn the P vector into a more useful form. This way we will have the identity
// a == rowm(L*U, ipiv). The permutation vector that comes out of LAPACK is somewhat
// different.
ipiv = trans(range(0, a.nr()-1));
for (long i = ipiv_temp.size()-1; i >= 0; --i)
{
// -1 because FORTRAN is indexed starting with 1 instead of 0
std::swap(ipiv(i), ipiv(ipiv_temp(i)-1));
}
return info;
} }
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "matrix_assign.h" #include "matrix_assign.h"
#include "matrix_conj_trans.h" #include "matrix_conj_trans.h"
#include "cblas_constants.h"
//#include <iostream> //#include <iostream>
//using namespace std; //using namespace std;
...@@ -41,9 +42,6 @@ namespace dlib ...@@ -41,9 +42,6 @@ namespace dlib
{ {
// Here we declare the prototypes for the CBLAS calls used by the BLAS bindings below // Here we declare the prototypes for the CBLAS calls used by the BLAS bindings below
enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const float alpha, const float *A, const int K, const float alpha, const float *A,
......
...@@ -393,166 +393,6 @@ namespace dlib ...@@ -393,166 +393,6 @@ namespace dlib
return true; return true;
} }
template <
    typename T,
    long N,
    long NX,
    typename MM1,
    typename MM2,
    typename MM3,
    typename L1,
    typename L2,
    typename L3
    >
bool ludcmp (
    matrix<T,N,N,MM1,L1>& a,
    matrix<long,N,NX,MM2,L2>& indx,
    T& d,
    matrix<T,N,NX,MM3,L3>& vv
)
/*!
    ( this function is derived from the one in numerical recipes in C chapter 2.3)
    requires
        - indx and vv are column vectors
    ensures
        - #a == both the L and U matrices, stored in place
        - #indx == the permutation vector (see numerical recipes in C).
          #indx(j) is the row that was swapped with row j while processing column j.
        - #d == +1 or -1 depending on whether the number of row interchanges
          performed was even or odd
        - #vv == some undefined value.  this is just used for scratch space
        - if (the matrix is singular and we can't do anything) then
            - returns false (a may have been partially overwritten)
        - else
            - returns true
!*/
{
    DLIB_ASSERT(indx.nc() == 1,"in dlib::nric::ludcmp() the indx matrix must be a column vector");
    DLIB_ASSERT(vv.nc() == 1,"in dlib::nric::ludcmp() the vv matrix must be a column vector");

    const long n = a.nr();
    long imax = 0;
    T big, dum, sum, temp;
    d = 1.0;

    // Record the implicit scaling of each row (1/largest absolute element).
    for (long i = 0; i < n; ++i)
    {
        big = 0;
        for (long j = 0; j < n; ++j)
        {
            if ((temp=std::abs(a(i,j))) > big)
                big = temp;
        }
        if (big == 0.0)
        {
            // an all zero row means the matrix is singular
            return false;
        }
        vv(i) = 1/big;
    }

    // Process the matrix one column at a time.
    for (long j = 0; j < n; ++j)
    {
        // elements above the diagonal
        for (long i = 0; i < j; ++i)
        {
            sum = a(i,j);
            for (long k = 0; k < i; ++k)
                sum -= a(i,k)*a(k,j);
            a(i,j) = sum;
        }

        // elements on and below the diagonal, while searching for the best
        // (scaled) pivot candidate
        big = 0;
        for (long i = j; i < n; ++i)
        {
            sum = a(i,j);
            for (long k = 0; k < j; ++k)
                sum -= a(i,k)*a(k,j);
            a(i,j) = sum;
            if ( (dum=vv(i)*std::abs(sum)) >= big)
            {
                big = dum;
                imax = i;
            }
        }

        // move the pivot row into place
        if (j != imax)
        {
            for (long k = 0; k < n; ++k)
            {
                dum = a(imax,k);
                a(imax,k) = a(j,k);
                a(j,k) = dum;
            }
            d = -d;
            vv(imax) = vv(j);
        }
        indx(j) = imax;

        // A zero pivot means the matrix is singular.  This has to be checked for
        // every column, including the last one, since lubksb() divides by each
        // diagonal element of the decomposition.
        if (a(j,j) == 0)
            return false;

        if (j < n-1)
        {
            dum = 1/a(j,j);
            for (long i = j+1; i < n; ++i)
                a(i,j) *= dum;
        }
    }
    return true;
}
// ----------------------------------------------------------------------------------------
template <
    typename T,
    long N,
    long NX,
    typename MM1,
    typename MM2,
    typename MM3,
    typename L1,
    typename L2,
    typename L3
    >
void lubksb (
    const matrix<T,N,N,MM1,L1>& a,
    const matrix<long,N,NX,MM2,L2>& indx,
    matrix<T,N,NX,MM3,L3>& b
)
/*!
    ( this function is derived from the one in numerical recipes in C chapter 2.3)
    requires
        - a == the LU decomposition produced by ludcmp()
        - indx == the permutation vector produced by ludcmp()
        - b == the right hand side column vector of the system a*x = b
    ensures
        - #b == the solution x of a*x = b.  The right hand side is overwritten
          in place with the solution.
!*/
{
    DLIB_ASSERT(indx.nc() == 1,"in dlib::nric::lubksb() the indx matrix must be a column vector");
    DLIB_ASSERT(b.nc() == 1,"in dlib::nric::lubksb() the b matrix must be a column vector");

    const long dim = a.nr();

    // Index of the first non-zero element seen in the permuted right hand side,
    // or -1 if none has been seen yet.  Leading zeros are skipped in the sums below.
    long first_nonzero = -1;

    // Forward substitution, undoing the row permutation as we go.
    for (long row = 0; row < dim; ++row)
    {
        const long pivot_row = indx(row);
        T acc = b(pivot_row);
        b(pivot_row) = b(row);

        if (first_nonzero == -1)
        {
            if (acc)
                first_nonzero = row;
        }
        else
        {
            for (long col = first_nonzero; col < row; ++col)
                acc -= a(row,col)*b(col);
        }
        b(row) = acc;
    }

    // Back substitution using the upper triangular part of a.
    for (long row = dim-1; row >= 0; --row)
    {
        T acc = b(row);
        for (long col = row+1; col < dim; ++col)
            acc -= a(row,col)*b(col);
        b(row) = acc/a(row,row);
    }
}
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
} }
...@@ -1001,31 +841,8 @@ convergence: ...@@ -1001,31 +841,8 @@ convergence:
); );
typedef typename matrix_exp<EXP>::type type; typedef typename matrix_exp<EXP>::type type;
matrix<type, N, N,MM> a(m), y(m.nr(),m.nr()); lu_decomposition<EXP> lu(m);
matrix<long,N,1,MM> indx(m.nr(),1); return lu.solve(identity_matrix<type>(m.nr()));
matrix<type,N,1,MM> col(m.nr(),1);
matrix<type,N,1,MM> vv(m.nr(),1);
type d;
long i, j;
if (ludcmp(a,indx,d,vv))
{
for (j = 0; j < m.nr(); ++j)
{
for (i = 0; i < m.nr(); ++i)
col(i) = 0;
col(j) = 1;
lubksb(a,indx,col);
for (i = 0; i < m.nr(); ++i)
y(i,j) = col(i);
}
}
else
{
// m is singular so lets just set y equal to m just so that
// it has some value
y = m;
}
return y;
} }
}; };
...@@ -1502,17 +1319,7 @@ convergence: ...@@ -1502,17 +1319,7 @@ convergence:
typedef typename matrix_exp<EXP>::type type; typedef typename matrix_exp<EXP>::type type;
typedef typename matrix_exp<EXP>::mem_manager_type MM; typedef typename matrix_exp<EXP>::mem_manager_type MM;
matrix<type, N, N,MM> lu(m); return lu_decomposition<EXP>(m).det();
matrix<long,N,1,MM> indx(m.nr(),1);
matrix<type,N,1,MM> vv(m.nr(),1);
type d;
if (ludcmp(lu,indx,d,vv) == false)
{
// the matrix is singular so its det is 0
return 0;
}
return prod(diag(lu))*d;
} }
}; };
......
...@@ -255,6 +255,9 @@ namespace dlib ...@@ -255,6 +255,9 @@ namespace dlib
LU decomposition is in the solution of square systems of simultaneous LU decomposition is in the solution of square systems of simultaneous
linear equations. This will fail if is_singular() returns true (or linear equations. This will fail if is_singular() returns true (or
if A is very nearly singular). if A is very nearly singular).
If DLIB_USE_LAPACK is defined then the LAPACK routine xGETRF
is used to compute the LU decomposition.
!*/ !*/
public: public:
......
...@@ -8,8 +8,14 @@ ...@@ -8,8 +8,14 @@
#include "matrix.h" #include "matrix.h"
#include "matrix_utilities.h" #include "matrix_utilities.h"
#include "matrix_subexp.h" #include "matrix_subexp.h"
#include "matrix_trsm.h"
#include <algorithm> #include <algorithm>
#ifdef DLIB_USE_LAPACK
#include "lapack/getrf.h"
#endif
namespace dlib namespace dlib
{ {
...@@ -72,7 +78,7 @@ namespace dlib ...@@ -72,7 +78,7 @@ namespace dlib
private: private:
/* Array for internal storage of decomposition. */ /* Array for internal storage of decomposition. */
matrix_type LU; matrix<type,0,0,mem_manager_type,column_major_layout> LU;
long m, n, pivsign; long m, n, pivsign;
pivot_column_vector_type piv; pivot_column_vector_type piv;
...@@ -108,6 +114,28 @@ namespace dlib ...@@ -108,6 +114,28 @@ namespace dlib
<< "\n\tthis: " << this << "\n\tthis: " << this
); );
#ifdef DLIB_USE_LAPACK
matrix<lapack::integer,0,1,mem_manager_type,layout_type> piv_temp;
lapack::getrf(LU, piv_temp);
pivsign = 1;
// Turn the piv_temp vector into a more useful form. This way we will have the identity
// rowm(A,piv) == L*U. The permutation vector that comes out of LAPACK is somewhat
// different.
piv = trans(range(0,m-1));
for (long i = 0; i < piv_temp.size(); ++i)
{
// -1 because FORTRAN is indexed starting with 1 instead of 0
if (piv(piv_temp(i)-1) != piv(i))
{
std::swap(piv(i), piv(piv_temp(i)-1));
pivsign = -pivsign;
}
}
#else
// Use a "left-looking", dot-product, Crout/Doolittle algorithm. // Use a "left-looking", dot-product, Crout/Doolittle algorithm.
...@@ -170,6 +198,8 @@ namespace dlib ...@@ -170,6 +198,8 @@ namespace dlib
} }
} }
} }
#endif
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -311,67 +341,15 @@ namespace dlib ...@@ -311,67 +341,15 @@ namespace dlib
<< "\n\tthis: " << this << "\n\tthis: " << this
); );
const long nx = B.nc(); // Copy right hand side with pivoting
// if there are multiple columns in B matrix<type,0,0,mem_manager_type,column_major_layout> X(rowm(B, piv));
if (nx > 1)
{
// Copy right hand side with pivoting
matrix_type X(rowm(B, piv));
// Solve L*Y = B(piv,:) using namespace blas_bindings;
for (long k = 0; k < n; k++) // Solve L*Y = B(piv,:)
{ triangular_solver(CblasLeft, CblasLower, CblasNoTrans, CblasUnit, LU, X);
for (long i = k+1; i < n; i++) // Solve U*X = Y;
{ triangular_solver(CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, LU, X);
for (long j = 0; j < nx; j++) return X;
{
X(i,j) -= X(k,j)*LU(i,k);
}
}
}
// Solve U*X = Y;
for (long k = n-1; k >= 0; k--)
{
for (long j = 0; j < nx; j++)
{
X(k,j) /= LU(k,k);
}
for (long i = 0; i < k; i++)
{
for (long j = 0; j < nx; j++)
{
X(i,j) -= X(k,j)*LU(i,k);
}
}
}
return X;
}
else
{
column_vector_type x(rowm(B, piv));
// Solve L*Y = B(piv)
for (long k = 0; k < n; k++)
{
for (long i = k+1; i < n; i++)
{
x(i) -= x(k)*LU(i,k);
}
}
// Solve U*X = Y;
for (long k = n-1; k >= 0; k--)
{
x(k) /= LU(k,k);
for (long i = 0; i < k; i++)
x(i) -= x(k)*LU(i,k);
}
return x;
}
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include "lapack/fortran_id.h"
#include "cblas_constants.h"
namespace dlib
{
namespace blas_bindings
{
// Prototypes for the single precision (cblas_strsm) and double precision
// (cblas_dtrsm) CBLAS triangular solve routines.  The cblas_trsm() overloads
// in this file call these only when DLIB_USE_BLAS is defined; otherwise
// local_trsm() is used instead.
extern "C"
{
void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
}
// ------------------------------------------------------------------------------------
/* Purpose */
/* ======= */
/* DTRSM solves one of the matrix equations */
/* op( A )*X = alpha*B, or X*op( A ) = alpha*B, */
/* where alpha is a scalar, X and B are m by n matrices, A is a unit, or */
/* non-unit, upper or lower triangular matrix and op( A ) is one of */
/* op( A ) = A or op( A ) = A'. */
/* The matrix X is overwritten on B. */
/* Arguments */
/* ========== */
/* SIDE - CHARACTER*1. */
/* On entry, SIDE specifies whether op( A ) appears on the left */
/* or right of X as follows: */
/* SIDE = 'L' or 'l' op( A )*X = alpha*B. */
/* SIDE = 'R' or 'r' X*op( A ) = alpha*B. */
/* Unchanged on exit. */
/* UPLO - CHARACTER*1. */
/* On entry, UPLO specifies whether the matrix A is an upper or */
/* lower triangular matrix as follows: */
/* UPLO = 'U' or 'u' A is an upper triangular matrix. */
/* UPLO = 'L' or 'l' A is a lower triangular matrix. */
/* Unchanged on exit. */
/* TRANSA - CHARACTER*1. */
/* On entry, TRANSA specifies the form of op( A ) to be used in */
/* the matrix multiplication as follows: */
/* TRANSA = 'N' or 'n' op( A ) = A. */
/* TRANSA = 'T' or 't' op( A ) = A'. */
/* TRANSA = 'C' or 'c' op( A ) = A'. */
/* Unchanged on exit. */
/* DIAG - CHARACTER*1. */
/* On entry, DIAG specifies whether or not A is unit triangular */
/* as follows: */
/* DIAG = 'U' or 'u' A is assumed to be unit triangular. */
/* DIAG = 'N' or 'n' A is not assumed to be unit */
/* triangular. */
/* Unchanged on exit. */
/* M - INTEGER. */
/* On entry, M specifies the number of rows of B. M must be at */
/* least zero. */
/* Unchanged on exit. */
/* N - INTEGER. */
/* On entry, N specifies the number of columns of B. N must be */
/* at least zero. */
/* Unchanged on exit. */
/* ALPHA - DOUBLE PRECISION. */
/* On entry, ALPHA specifies the scalar alpha. When alpha is */
/* zero then A is not referenced and B need not be set before */
/* entry. */
/* Unchanged on exit. */
/* A - DOUBLE PRECISION array of DIMENSION ( LDA, k ), where k is m */
/* when SIDE = 'L' or 'l' and is n when SIDE = 'R' or 'r'. */
/* Before entry with UPLO = 'U' or 'u', the leading k by k */
/* upper triangular part of the array A must contain the upper */
/* triangular matrix and the strictly lower triangular part of */
/* A is not referenced. */
/* Before entry with UPLO = 'L' or 'l', the leading k by k */
/* lower triangular part of the array A must contain the lower */
/* triangular matrix and the strictly upper triangular part of */
/* A is not referenced. */
/* Note that when DIAG = 'U' or 'u', the diagonal elements of */
/* A are not referenced either, but are assumed to be unity. */
/* Unchanged on exit. */
/* LDA - INTEGER. */
/* On entry, LDA specifies the first dimension of A as declared */
/* in the calling (sub) program. When SIDE = 'L' or 'l' then */
/* LDA must be at least max( 1, m ), when SIDE = 'R' or 'r' */
/* then LDA must be at least max( 1, n ). */
/* Unchanged on exit. */
/* B - DOUBLE PRECISION array of DIMENSION ( LDB, n ). */
/* Before entry, the leading m by n part of the array B must */
/* contain the right-hand side matrix B, and on exit is */
/* overwritten by the solution matrix X. */
/* LDB - INTEGER. */
/* On entry, LDB specifies the first dimension of B as declared */
/* in the calling (sub) program. LDB must be at least */
/* max( 1, m ). */
/* Unchanged on exit. */
/* Level 3 Blas routine. */
/* -- Written on 8-February-1989. */
/* Jack Dongarra, Argonne National Laboratory. */
/* Iain Duff, AERE Harwell. */
/* Jeremy Du Croz, Numerical Algorithms Group Ltd. */
/* Sven Hammarling, Numerical Algorithms Group Ltd. */
template <typename T>
void local_trsm(
const enum CBLAS_ORDER Order,
enum CBLAS_SIDE Side,
enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag,
long m,
long n,
T alpha,
const T *a,
long lda,
T *b,
long ldb
)
/*!
This is a copy of the dtrsm routine from the netlib.org BLAS which was run through
f2c and converted into this form for use when a BLAS library is not available.

It solves op(A)*X = alpha*B (Side == CblasLeft) or X*op(A) = alpha*B
(Side == CblasRight) for X, where A is triangular, and overwrites b with X.
See the DTRSM argument description above for the meaning of each argument.

- Order:  storage order of a and b.  Row major inputs are handled by flipping
Side and Uplo and swapping m and n before running the column major algorithm.
- m, n:   number of rows and columns of B.
- a, lda: the triangular matrix and its leading dimension.
- b, ldb: on entry the right hand side, on exit the solution, and its leading
dimension.

Invalid arguments trigger DLIB_CASSERT.
!*/
{
if (Order == CblasRowMajor)
{
// since row major ordering looks like transposition to FORTRAN we need to flip a
// few things.
if (Side == CblasLeft)
Side = CblasRight;
else
Side = CblasLeft;
if (Uplo == CblasUpper)
Uplo = CblasLower;
else
Uplo = CblasUpper;
std::swap(m,n);
}
/* System generated locals */
long a_dim1, a_offset, b_dim1, b_offset, i__1, i__2, i__3;
/* Local variables */
long i__, j, k, info;
T temp;
bool lside;
long nrowa;
bool upper;
bool nounit;
/* Parameter adjustments */
// Shift both pointers so that the FORTRAN style 1-based expression
// x[i + j * x_dim1] refers to element (i-1) + (j-1)*ld of the original array.
a_dim1 = lda;
a_offset = 1 + a_dim1;
a -= a_offset;
b_dim1 = ldb;
b_offset = 1 + b_dim1;
b -= b_offset;
/* Function Body */
lside = (Side == CblasLeft);
// nrowa is the order of the triangular matrix A: m when A is on the left, n otherwise.
if (lside)
{
nrowa = m;
} else
{
nrowa = n;
}
nounit = (Diag == CblasNonUnit);
upper = (Uplo == CblasUpper);
// Validate the arguments.  info holds the (FORTRAN style) position of the
// first invalid argument, or 0 if everything is valid.
info = 0;
if (! lside && ! (Side == CblasRight)) {
info = 1;
} else if (! upper && !(Uplo == CblasLower) ) {
info = 2;
} else if (!(TransA == CblasNoTrans) &&
!(TransA == CblasTrans) &&
!(TransA == CblasConjTrans)) {
info = 3;
} else if (!(Diag == CblasUnit) &&
!(Diag == CblasNonUnit) ) {
info = 4;
} else if (m < 0) {
info = 5;
} else if (n < 0) {
info = 6;
} else if (lda < std::max<long>(1,nrowa)) {
info = 9;
} else if (ldb < std::max<long>(1,m)) {
info = 11;
}
DLIB_CASSERT( info == 0, "Invalid inputs given to local_trsm");
/* Quick return if possible. */
if (m == 0 || n == 0) {
return;
}
/* And when alpha.eq.zero. */
// The solution is all zeros in this case and A is never read.
if (alpha == 0.) {
i__1 = n;
for (j = 1; j <= i__1; ++j) {
i__2 = m;
for (i__ = 1; i__ <= i__2; ++i__) {
b[i__ + j * b_dim1] = 0.;
/* L10: */
}
/* L20: */
}
return;
}
/* Start the operations. */
// Any TransA value other than CblasNoTrans (i.e. CblasTrans or CblasConjTrans)
// is handled by the transposed branches below.
if (lside) {
if (TransA == CblasNoTrans) {
/* Form B := alpha*inv( A )*B. */
if (upper) {
// back substitution, one column of B at a time
i__1 = n;
for (j = 1; j <= i__1; ++j) {
if (alpha != 1.) {
i__2 = m;
for (i__ = 1; i__ <= i__2; ++i__) {
b[i__ + j * b_dim1] = alpha * b[i__ + j * b_dim1]
;
/* L30: */
}
}
for (k = m; k >= 1; --k) {
if (b[k + j * b_dim1] != 0.) {
if (nounit) {
b[k + j * b_dim1] /= a[k + k * a_dim1];
}
i__2 = k - 1;
for (i__ = 1; i__ <= i__2; ++i__) {
b[i__ + j * b_dim1] -= b[k + j * b_dim1] * a[
i__ + k * a_dim1];
/* L40: */
}
}
/* L50: */
}
/* L60: */
}
} else {
// forward substitution, one column of B at a time
i__1 = n;
for (j = 1; j <= i__1; ++j) {
if (alpha != 1.) {
i__2 = m;
for (i__ = 1; i__ <= i__2; ++i__) {
b[i__ + j * b_dim1] = alpha * b[i__ + j * b_dim1]
;
/* L70: */
}
}
i__2 = m;
for (k = 1; k <= i__2; ++k) {
if (b[k + j * b_dim1] != 0.) {
if (nounit) {
b[k + j * b_dim1] /= a[k + k * a_dim1];
}
i__3 = m;
for (i__ = k + 1; i__ <= i__3; ++i__) {
b[i__ + j * b_dim1] -= b[k + j * b_dim1] * a[
i__ + k * a_dim1];
/* L80: */
}
}
/* L90: */
}
/* L100: */
}
}
} else {
/* Form B := alpha*inv( A' )*B. */
if (upper) {
i__1 = n;
for (j = 1; j <= i__1; ++j) {
i__2 = m;
for (i__ = 1; i__ <= i__2; ++i__) {
temp = alpha * b[i__ + j * b_dim1];
i__3 = i__ - 1;
for (k = 1; k <= i__3; ++k) {
temp -= a[k + i__ * a_dim1] * b[k + j * b_dim1];
/* L110: */
}
if (nounit) {
temp /= a[i__ + i__ * a_dim1];
}
b[i__ + j * b_dim1] = temp;
/* L120: */
}
/* L130: */
}
} else {
i__1 = n;
for (j = 1; j <= i__1; ++j) {
for (i__ = m; i__ >= 1; --i__) {
temp = alpha * b[i__ + j * b_dim1];
i__2 = m;
for (k = i__ + 1; k <= i__2; ++k) {
temp -= a[k + i__ * a_dim1] * b[k + j * b_dim1];
/* L140: */
}
if (nounit) {
temp /= a[i__ + i__ * a_dim1];
}
b[i__ + j * b_dim1] = temp;
/* L150: */
}
/* L160: */
}
}
}
} else {
// A is on the right hand side of X.
if (TransA == CblasNoTrans) {
/* Form B := alpha*B*inv( A ). */
if (upper) {
i__1 = n;
for (j = 1; j <= i__1; ++j) {
if (alpha != 1.) {
i__2 = m;
for (i__ = 1; i__ <= i__2; ++i__) {
b[i__ + j * b_dim1] = alpha * b[i__ + j * b_dim1]
;
/* L170: */
}
}
i__2 = j - 1;
for (k = 1; k <= i__2; ++k) {
if (a[k + j * a_dim1] != 0.) {
i__3 = m;
for (i__ = 1; i__ <= i__3; ++i__) {
b[i__ + j * b_dim1] -= a[k + j * a_dim1] * b[
i__ + k * b_dim1];
/* L180: */
}
}
/* L190: */
}
if (nounit) {
temp = 1. / a[j + j * a_dim1];
i__2 = m;
for (i__ = 1; i__ <= i__2; ++i__) {
b[i__ + j * b_dim1] = temp * b[i__ + j * b_dim1];
/* L200: */
}
}
/* L210: */
}
} else {
for (j = n; j >= 1; --j) {
if (alpha != 1.) {
i__1 = m;
for (i__ = 1; i__ <= i__1; ++i__) {
b[i__ + j * b_dim1] = alpha * b[i__ + j * b_dim1]
;
/* L220: */
}
}
i__1 = n;
for (k = j + 1; k <= i__1; ++k) {
if (a[k + j * a_dim1] != 0.) {
i__2 = m;
for (i__ = 1; i__ <= i__2; ++i__) {
b[i__ + j * b_dim1] -= a[k + j * a_dim1] * b[
i__ + k * b_dim1];
/* L230: */
}
}
/* L240: */
}
if (nounit) {
temp = 1. / a[j + j * a_dim1];
i__1 = m;
for (i__ = 1; i__ <= i__1; ++i__) {
b[i__ + j * b_dim1] = temp * b[i__ + j * b_dim1];
/* L250: */
}
}
/* L260: */
}
}
} else {
/* Form B := alpha*B*inv( A' ). */
if (upper) {
for (k = n; k >= 1; --k) {
if (nounit) {
temp = 1. / a[k + k * a_dim1];
i__1 = m;
for (i__ = 1; i__ <= i__1; ++i__) {
b[i__ + k * b_dim1] = temp * b[i__ + k * b_dim1];
/* L270: */
}
}
i__1 = k - 1;
for (j = 1; j <= i__1; ++j) {
if (a[j + k * a_dim1] != 0.) {
temp = a[j + k * a_dim1];
i__2 = m;
for (i__ = 1; i__ <= i__2; ++i__) {
b[i__ + j * b_dim1] -= temp * b[i__ + k *
b_dim1];
/* L280: */
}
}
/* L290: */
}
if (alpha != 1.) {
i__1 = m;
for (i__ = 1; i__ <= i__1; ++i__) {
b[i__ + k * b_dim1] = alpha * b[i__ + k * b_dim1]
;
/* L300: */
}
}
/* L310: */
}
} else {
i__1 = n;
for (k = 1; k <= i__1; ++k) {
if (nounit) {
temp = 1. / a[k + k * a_dim1];
i__2 = m;
for (i__ = 1; i__ <= i__2; ++i__) {
b[i__ + k * b_dim1] = temp * b[i__ + k * b_dim1];
/* L320: */
}
}
i__2 = n;
for (j = k + 1; j <= i__2; ++j) {
if (a[j + k * a_dim1] != 0.) {
temp = a[j + k * a_dim1];
i__3 = m;
for (i__ = 1; i__ <= i__3; ++i__) {
b[i__ + j * b_dim1] -= temp * b[i__ + k *
b_dim1];
/* L330: */
}
}
/* L340: */
}
if (alpha != 1.) {
i__2 = m;
for (i__ = 1; i__ <= i__2; ++i__) {
b[i__ + k * b_dim1] = alpha * b[i__ + k * b_dim1]
;
/* L350: */
}
}
/* L360: */
}
}
}
}
}
// ------------------------------------------------------------------------------------
// Single precision triangular solve.  Uses dlib's own local_trsm() unless
// DLIB_USE_BLAS is defined, in which case the call is forwarded to the
// external cblas_strsm() routine.
inline void cblas_trsm(
    const enum CBLAS_ORDER Order,
    const enum CBLAS_SIDE Side,
    const enum CBLAS_UPLO Uplo,
    const enum CBLAS_TRANSPOSE TransA,
    const enum CBLAS_DIAG Diag,
    const int M,
    const int N,
    const float alpha,
    const float *A,
    const int lda,
    float *B,
    const int ldb
)
{
#ifndef DLIB_USE_BLAS
    local_trsm(Order, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb);
#else
    cblas_strsm(Order, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb);
#endif
}
// Double precision triangular solve.  Uses dlib's own local_trsm() unless
// DLIB_USE_BLAS is defined, in which case the call is forwarded to the
// external cblas_dtrsm() routine.
inline void cblas_trsm(
    const enum CBLAS_ORDER Order,
    const enum CBLAS_SIDE Side,
    const enum CBLAS_UPLO Uplo,
    const enum CBLAS_TRANSPOSE TransA,
    const enum CBLAS_DIAG Diag,
    const int M,
    const int N,
    const double alpha,
    const double *A,
    const int lda,
    double *B,
    const int ldb
)
{
#ifndef DLIB_USE_BLAS
    local_trsm(Order, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb);
#else
    cblas_dtrsm(Order, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb);
#endif
}
// Extended precision triangular solve.  This overload always goes through
// local_trsm(), whether or not DLIB_USE_BLAS is defined.
inline void cblas_trsm(
    const enum CBLAS_ORDER Order,
    const enum CBLAS_SIDE Side,
    const enum CBLAS_UPLO Uplo,
    const enum CBLAS_TRANSPOSE TransA,
    const enum CBLAS_DIAG Diag,
    const int M,
    const int N,
    const long double alpha,
    const long double *A,
    const int lda,
    long double *B,
    const int ldb
)
{
    local_trsm(Order, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb);
}
// ------------------------------------------------------------------------------------
template <
typename T,
long NR1, long NR2,
long NC1, long NC2,
typename MM
>
inline void triangular_solver (
const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag,
const matrix<T,NR1,NC1,MM,row_major_layout>& A,
const T alpha,
matrix<T,NR2,NC2,MM,row_major_layout>& B
)
{
cblas_trsm(CblasRowMajor, Side, Uplo, TransA, Diag, B.nr(), B.nc(),
alpha, &A(0,0), A.nc(), &B(0,0), B.nc());
}
// ------------------------------------------------------------------------------------
template <
typename T,
long NR1, long NR2,
long NC1, long NC2,
typename MM
>
inline void triangular_solver (
const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag,
const matrix<T,NR1,NC1,MM,column_major_layout>& A,
const T alpha,
matrix<T,NR2,NC2,MM,column_major_layout>& B
)
{
cblas_trsm(CblasColMajor, Side, Uplo, TransA, Diag, B.nr(), B.nc(),
alpha, &A(0,0), A.nr(), &B(0,0), B.nr());
}
// ------------------------------------------------------------------------------------
template <
    typename T,
    long NR1, long NR2,
    long NC1, long NC2,
    typename MM,
    typename layout
    >
inline void triangular_solver (
    const enum CBLAS_SIDE Side,
    const enum CBLAS_UPLO Uplo,
    const enum CBLAS_TRANSPOSE TransA,
    const enum CBLAS_DIAG Diag,
    const matrix<T,NR1,NC1,MM,layout>& A,
    matrix<T,NR2,NC2,MM,layout>& B
)
/*!
    Convenience overload with no scaling: calls the triangular_solver()
    overload that takes an alpha argument, passing alpha == 1.
!*/
{
    const T unit_scale = 1;
    triangular_solver(Side, Uplo, TransA, Diag, A, unit_scale, B);
}
// ------------------------------------------------------------------------------------
}
}
...@@ -25,6 +25,8 @@ namespace ...@@ -25,6 +25,8 @@ namespace
logger dlog("test.matrix2"); logger dlog("test.matrix2");
dlib::rand::float_1a rnd;
void matrix_test ( void matrix_test (
) )
/*! /*!
...@@ -370,16 +372,11 @@ namespace ...@@ -370,16 +372,11 @@ namespace
matrix<double, 7, 7,MM,column_major_layout> m7; matrix<double, 7, 7,MM,column_major_layout> m7;
matrix<double> dm7(7,7); matrix<double> dm7(7,7);
for (long r= 0; r< dm7.nr(); ++r) dm7 = randm(7,7, rnd);
{
for (long c = 0; c < dm7.nc(); ++c)
{
dm7(r,c) = r*c/3.3;
}
}
m7 = dm7; m7 = dm7;
DLIB_TEST(inv(dm7) == inv(m7)); DLIB_TEST_MSG(max(abs(dm7*inv(dm7) - identity_matrix<double>(7))) < 1e-12, max(abs(dm7*inv(dm7) - identity_matrix<double>(7))));
DLIB_TEST(equal(inv(dm7), inv(m7)));
DLIB_TEST(det(dm7) == det(m7)); DLIB_TEST(det(dm7) == det(m7));
DLIB_TEST(min(dm7) == min(m7)); DLIB_TEST(min(dm7) == min(m7));
DLIB_TEST(max(dm7) == max(m7)); DLIB_TEST(max(dm7) == max(m7));
......
...@@ -110,6 +110,17 @@ namespace ...@@ -110,6 +110,17 @@ namespace
DLIB_TEST(max(abs(test.get_imag_eigenvalues())) < eps); DLIB_TEST(max(abs(test.get_imag_eigenvalues())) < eps);
DLIB_TEST(diagm(diag(D)) == D); DLIB_TEST(diagm(diag(D)) == D);
// only check the determinant against the eigenvalues for small matrices
// because for huge ones the determinant might be so big it overflows a floating point number.
if (m.nr() < 50)
{
const type mdet = det(m);
DLIB_TEST_MSG(std::abs(prod(test.get_real_eigenvalues()) - mdet) < std::abs(mdet)*sqrt(std::numeric_limits<type>::epsilon()),
std::abs(prod(test.get_real_eigenvalues()) - mdet) <<" eps: " << std::abs(mdet)*sqrt(std::numeric_limits<type>::epsilon())
<< " mdet: "<< mdet << " prod(eig): " << prod(test.get_real_eigenvalues())
);
}
// V is orthogonal // V is orthogonal
DLIB_TEST(equal(V*trans(V), identity_matrix<type>(test.dim()), eps)); DLIB_TEST(equal(V*trans(V), identity_matrix<type>(test.dim()), eps));
DLIB_TEST(equal(m , V*D*trans(V), eps)); DLIB_TEST(equal(m , V*D*trans(V), eps));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment