Commit a99fc566 authored by Davis King

Made svd_fast() accept a wider range of matrices as arguments.

parent 6391a034
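A minimal usage sketch of the widened interface (the matrix sizes, the randm() fill, and l = 10 below are illustrative choices, not taken from the commit): the singular values w can now be passed as a plain matrix<double> rather than the column vector type matrix<double,0,1>, and the other arguments may carry static dimensions or non-default layouts.

#include <dlib/matrix.h>

int main()
{
    using namespace dlib;

    // Illustrative input: a 100x50 matrix of random values (sizes are arbitrary).
    matrix<double> A = randm(100, 50);

    // Before this commit w had to be declared as matrix<double,0,1>; afterwards a
    // general matrix<double> (and other compatible matrix types) is also accepted.
    matrix<double> u, w, v;

    // Approximate SVD keeping l = 10 singular values, so A ~= u*diagm(w)*trans(v).
    svd_fast(A, u, w, v, 10);
}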
@@ -835,12 +835,17 @@ convergence:
 // ----------------------------------------------------------------------------------------
 // ----------------------------------------------------------------------------------------
 
-    template <typename T>
+    template <
+        typename T,
+        long Anr, long Anc,
+        typename MM,
+        typename L
+        >
     void find_matrix_range (
-        const matrix<T>& A,
+        const matrix<T,Anr,Anc,MM,L>& A,
         unsigned long l,
-        matrix<T>& Q,
-        unsigned long q = 0
+        matrix<T,Anr,0,MM,L>& Q,
+        unsigned long q
     )
     /*!
         requires
@@ -882,12 +887,20 @@ convergence:
 // ----------------------------------------------------------------------------------------
 
-    template <typename T>
+    template <
+        typename T,
+        long Anr, long Anc,
+        long Unr, long Unc,
+        long Wnr, long Wnc,
+        long Vnr, long Vnc,
+        typename MM,
+        typename L
+        >
     void svd_fast (
-        const matrix<T>& A,
-        matrix<T>& u,
-        matrix<T,0,1>& w,
-        matrix<T>& v,
+        const matrix<T,Anr,Anc,MM,L>& A,
+        matrix<T,Unr,Unc,MM,L>& u,
+        matrix<T,Wnr,Wnc,MM,L>& w,
+        matrix<T,Vnr,Vnc,MM,L>& v,
         unsigned long l,
         unsigned long q = 1
     )
@@ -901,26 +914,31 @@ convergence:
             << "\n\t A.size(): " << A.size()
             );
 
-        matrix<T> Q;
+        matrix<T,Anr,0,MM,L> Q;
         find_matrix_range(A, k, Q, q);
 
         // Compute trans(B) = trans(Q)*A. The reason we store B transposed
         // is so that when we take its SVD later using svd3() it doesn't consume
         // a whole lot of RAM. That is, we make sure the square matrix coming out
         // of svd3() has size lxl rather than the potentially much larger nxn.
-        matrix<T> B = trans(A)*Q;
+        matrix<T,0,0,MM,L> B = trans(A)*Q;
         svd3(B, v,w,u);
         u = Q*u;
     }
 
 // ----------------------------------------------------------------------------------------
 
-    template <typename sparse_vector_type, typename T>
+    template <
+        typename sparse_vector_type,
+        typename T,
+        typename MM,
+        typename L
+        >
     void find_matrix_range (
         const std::vector<sparse_vector_type>& A,
         unsigned long l,
-        matrix<T>& Q,
-        unsigned long q = 0
+        matrix<T,0,0,MM,L>& Q,
+        unsigned long q
     )
     /*!
         requires
@@ -962,7 +980,7 @@ convergence:
         const unsigned long n = max_index_plus_one(A);
         for (unsigned long itr = 0; itr < q; ++itr)
         {
-            matrix<T> Z(n, l);
+            matrix<T,0,0,MM,L> Z(n, l);
             // Compute Z = trans(A)*Q
             Z = 0;
             for (unsigned long m = 0; m < A.size(); ++m)
@@ -1001,12 +1019,20 @@ convergence:
 // ----------------------------------------------------------------------------------------
 
-    template <typename sparse_vector_type, typename T>
+    template <
+        typename sparse_vector_type,
+        typename T,
+        long Unr, long Unc,
+        long Wnr, long Wnc,
+        long Vnr, long Vnc,
+        typename MM,
+        typename L
+        >
     void svd_fast (
         const std::vector<sparse_vector_type>& A,
-        matrix<T>& u,
-        matrix<T,0,1>& w,
-        matrix<T>& v,
+        matrix<T,Unr,Unc,MM,L>& u,
+        matrix<T,Wnr,Wnc,MM,L>& w,
+        matrix<T,Vnr,Vnc,MM,L>& v,
         unsigned long l,
         unsigned long q = 1
     )
@@ -1022,14 +1048,14 @@ convergence:
             << "\n\t A.size(): " << A.size()
             );
 
-        matrix<T> Q;
+        matrix<T,0,0,MM,L> Q;
         find_matrix_range(A, k, Q, q);
 
         // Compute trans(B) = trans(Q)*A. The reason we store B transposed
         // is so that when we take its SVD later using svd3() it doesn't consume
         // a whole lot of RAM. That is, we make sure the square matrix coming out
         // of svd3() has size lxl rather than the potentially much larger nxn.
-        matrix<T> B(n,k);
+        matrix<T,0,0,MM,L> B(n,k);
         B = 0;
         for (unsigned long m = 0; m < A.size(); ++m)
         {
@@ -141,7 +141,7 @@ namespace dlib
     void svd_fast (
         const matrix<T>& A,
         matrix<T>& u,
-        matrix<T,0,1>& w,
+        matrix<T>& w,
         matrix<T>& v,
         unsigned long l,
         unsigned long q = 1
@@ -191,7 +191,7 @@ namespace dlib
     void svd_fast (
         const std::vector<sparse_vector_type>& A,
         matrix<T>& u,
-        matrix<T,0,1>& w,
+        matrix<T>& w,
         matrix<T>& v,
         unsigned long l,
         unsigned long q = 1
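The repeated comment about storing B transposed, so that the square matrix handled by svd3() is lxl rather than nxn, is the usual algebra of a randomized low-rank SVD. A rough sketch of that reasoning, written in the code's own notation and assuming svd3(a, u, w, v) factors a as u*diagm(w)*trans(v); the derivation below is mine, not part of the commit:

// Q is an orthonormal basis for the (approximate) range of A, with l columns, so
//     A ~= Q*trans(Q)*A.
// Let B = trans(A)*Q, i.e. trans(B) = trans(Q)*A.
// svd3(B, v, w, u) factors B = v*diagm(w)*trans(u), where u is the lxl square factor.
// Transposing: trans(Q)*A = trans(B) = u*diagm(w)*trans(v).
// Hence A ~= (Q*u)*diagm(w)*trans(v), which is why the code finishes with u = Q*u.

Because B has only l columns, the square factor produced inside svd3() stays small even when A has many columns.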