Commit 246a60c2 authored by Davis King's avatar Davis King

Added a version of get_q() to qr_decomposition that allows the user

to get the Q matrix by reference rather than by value.
parent 9376502c
...@@ -151,19 +151,44 @@ namespace dlib ...@@ -151,19 +151,44 @@ namespace dlib
typename T, typename T,
long NR1, long NR2, long NR3, long NR1, long NR2, long NR3,
long NC1, long NC2, long NC3, long NC1, long NC2, long NC3,
typename MM typename MM,
typename C_LAYOUT
> >
int ormqr ( int ormqr (
char side, char side,
char trans, char trans,
const matrix<T,NR1,NC1,MM,column_major_layout>& a, const matrix<T,NR1,NC1,MM,column_major_layout>& a,
const matrix<T,NR2,NC2,MM,column_major_layout>& tau, const matrix<T,NR2,NC2,MM,column_major_layout>& tau,
matrix<T,NR3,NC3,MM,column_major_layout>& c matrix<T,NR3,NC3,MM,C_LAYOUT>& c
) )
{ {
const long m = c.nr(); long m = c.nr();
const long n = c.nc(); long n = c.nc();
const long k = a.nc(); const long k = a.nc();
long ldc;
if (is_same_type<C_LAYOUT,column_major_layout>::value)
{
ldc = c.nr();
}
else
{
// Since lapack expects c to be in column major layout we have to
// do something to make this work. Since a row major layout matrix
// will look just like a transposed C we can just swap a few things around.
ldc = c.nc();
swap(m,n);
if (side == 'L')
side = 'R';
else
side = 'L';
if (trans == 'T')
trans = 'N';
else
trans = 'T';
}
matrix<T,0,1,MM,column_major_layout> work; matrix<T,0,1,MM,column_major_layout> work;
...@@ -171,7 +196,7 @@ namespace dlib ...@@ -171,7 +196,7 @@ namespace dlib
T work_size = 1; T work_size = 1;
int info = binding::ormqr(side, trans, m, n, int info = binding::ormqr(side, trans, m, n,
k, &a(0,0), a.nr(), &tau(0,0), k, &a(0,0), a.nr(), &tau(0,0),
&c(0,0), c.nr(), &work_size, -1); &c(0,0), ldc, &work_size, -1);
if (info != 0) if (info != 0)
return info; return info;
...@@ -182,7 +207,7 @@ namespace dlib ...@@ -182,7 +207,7 @@ namespace dlib
// compute the actual result // compute the actual result
info = binding::ormqr(side, trans, m, n, info = binding::ormqr(side, trans, m, n,
k, &a(0,0), a.nr(), &tau(0,0), k, &a(0,0), a.nr(), &tau(0,0),
&c(0,0), c.nr(), &work(0,0), work.size()); &c(0,0), ldc, &work(0,0), work.size());
return info; return info;
} }
......
...@@ -688,6 +688,16 @@ namespace dlib ...@@ -688,6 +688,16 @@ namespace dlib
- Q.nc() == nc() - Q.nc() == nc()
!*/ !*/
void get_q (
matrix_type& Q
) const;
/*!
ensures
- #Q == get_q()
- This function exists to allow a user to get the Q matrix without the
overhead of returning a matrix by value.
!*/
template <typename EXP> template <typename EXP>
const matrix_type solve ( const matrix_type solve (
const matrix_exp<EXP>& B const matrix_exp<EXP>& B
......
...@@ -62,6 +62,10 @@ namespace dlib ...@@ -62,6 +62,10 @@ namespace dlib
const matrix_type get_q ( const matrix_type get_q (
) const; ) const;
void get_q (
matrix_type& Q
) const;
template <typename EXP> template <typename EXP>
const matrix_type solve ( const matrix_type solve (
const matrix_exp<EXP>& B const matrix_exp<EXP>& B
...@@ -257,28 +261,39 @@ namespace dlib ...@@ -257,28 +261,39 @@ namespace dlib
const typename qr_decomposition<matrix_exp_type>::matrix_type qr_decomposition<matrix_exp_type>:: const typename qr_decomposition<matrix_exp_type>::matrix_type qr_decomposition<matrix_exp_type>::
get_q( get_q(
) const ) const
{
matrix_type Q;
get_q(Q);
return Q;
}
// ----------------------------------------------------------------------------------------
template <typename matrix_exp_type>
void qr_decomposition<matrix_exp_type>::
get_q(
matrix_type& X
) const
{ {
#ifdef DLIB_USE_LAPACK #ifdef DLIB_USE_LAPACK
matrix<type,0,0,mem_manager_type,column_major_layout> X;
// Take only the first n columns of an identity matrix. This way // Take only the first n columns of an identity matrix. This way
// X ends up being an m by n matrix. // X ends up being an m by n matrix.
X = colm(identity_matrix<type>(m), range(0,n-1)); X = colm(identity_matrix<type>(m), range(0,n-1));
// Compute Y = Q*X // Compute Y = Q*X
lapack::ormqr('L','N', QR_, tau, X); lapack::ormqr('L','N', QR_, tau, X);
return X;
#else #else
long i=0, j=0, k=0; long i=0, j=0, k=0;
matrix_type Q(m,n); X.set_size(m,n);
for (k = n-1; k >= 0; k--) for (k = n-1; k >= 0; k--)
{ {
for (i = 0; i < m; i++) for (i = 0; i < m; i++)
{ {
Q(i,k) = 0.0; X(i,k) = 0.0;
} }
Q(k,k) = 1.0; X(k,k) = 1.0;
for (j = k; j < n; j++) for (j = k; j < n; j++)
{ {
if (QR_(k,k) != 0) if (QR_(k,k) != 0)
...@@ -286,17 +301,16 @@ namespace dlib ...@@ -286,17 +301,16 @@ namespace dlib
type s = 0.0; type s = 0.0;
for (i = k; i < m; i++) for (i = k; i < m; i++)
{ {
s += QR_(i,k)*Q(i,j); s += QR_(i,k)*X(i,j);
} }
s = -s/QR_(k,k); s = -s/QR_(k,k);
for (i = k; i < m; i++) for (i = k; i < m; i++)
{ {
Q(i,j) += s*QR_(i,k); X(i,j) += s*QR_(i,k);
} }
} }
} }
} }
return Q;
#endif #endif
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment