Commit a99fc566 authored by Davis King's avatar Davis King

Made svd_fast() accept a wider range of matrices as arguments.

parent 6391a034
...@@ -835,12 +835,17 @@ convergence: ...@@ -835,12 +835,17 @@ convergence:
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename T> template <
typename T,
long Anr, long Anc,
typename MM,
typename L
>
void find_matrix_range ( void find_matrix_range (
const matrix<T>& A, const matrix<T,Anr,Anc,MM,L>& A,
unsigned long l, unsigned long l,
matrix<T>& Q, matrix<T,Anr,0,MM,L>& Q,
unsigned long q = 0 unsigned long q
) )
/*! /*!
requires requires
...@@ -882,12 +887,20 @@ convergence: ...@@ -882,12 +887,20 @@ convergence:
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename T> template <
typename T,
long Anr, long Anc,
long Unr, long Unc,
long Wnr, long Wnc,
long Vnr, long Vnc,
typename MM,
typename L
>
void svd_fast ( void svd_fast (
const matrix<T>& A, const matrix<T,Anr,Anc,MM,L>& A,
matrix<T>& u, matrix<T,Unr,Unc,MM,L>& u,
matrix<T,0,1>& w, matrix<T,Wnr,Wnc,MM,L>& w,
matrix<T>& v, matrix<T,Vnr,Vnc,MM,L>& v,
unsigned long l, unsigned long l,
unsigned long q = 1 unsigned long q = 1
) )
...@@ -901,26 +914,31 @@ convergence: ...@@ -901,26 +914,31 @@ convergence:
<< "\n\t A.size(): " << A.size() << "\n\t A.size(): " << A.size()
); );
matrix<T> Q; matrix<T,Anr,0,MM,L> Q;
find_matrix_range(A, k, Q, q); find_matrix_range(A, k, Q, q);
// Compute trans(B) = trans(Q)*A. The reason we store B transposed // Compute trans(B) = trans(Q)*A. The reason we store B transposed
// is so that when we take its SVD later using svd3() it doesn't consume // is so that when we take its SVD later using svd3() it doesn't consume
// a whole lot of RAM. That is, we make sure the square matrix coming out // a whole lot of RAM. That is, we make sure the square matrix coming out
// of svd3() has size lxl rather than the potentially much larger nxn. // of svd3() has size lxl rather than the potentially much larger nxn.
matrix<T> B = trans(A)*Q; matrix<T,0,0,MM,L> B = trans(A)*Q;
svd3(B, v,w,u); svd3(B, v,w,u);
u = Q*u; u = Q*u;
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename sparse_vector_type, typename T> template <
typename sparse_vector_type,
typename T,
typename MM,
typename L
>
void find_matrix_range ( void find_matrix_range (
const std::vector<sparse_vector_type>& A, const std::vector<sparse_vector_type>& A,
unsigned long l, unsigned long l,
matrix<T>& Q, matrix<T,0,0,MM,L>& Q,
unsigned long q = 0 unsigned long q
) )
/*! /*!
requires requires
...@@ -962,7 +980,7 @@ convergence: ...@@ -962,7 +980,7 @@ convergence:
const unsigned long n = max_index_plus_one(A); const unsigned long n = max_index_plus_one(A);
for (unsigned long itr = 0; itr < q; ++itr) for (unsigned long itr = 0; itr < q; ++itr)
{ {
matrix<T> Z(n, l); matrix<T,0,0,MM,L> Z(n, l);
// Compute Z = trans(A)*Q // Compute Z = trans(A)*Q
Z = 0; Z = 0;
for (unsigned long m = 0; m < A.size(); ++m) for (unsigned long m = 0; m < A.size(); ++m)
...@@ -1001,12 +1019,20 @@ convergence: ...@@ -1001,12 +1019,20 @@ convergence:
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename sparse_vector_type, typename T> template <
typename sparse_vector_type,
typename T,
long Unr, long Unc,
long Wnr, long Wnc,
long Vnr, long Vnc,
typename MM,
typename L
>
void svd_fast ( void svd_fast (
const std::vector<sparse_vector_type>& A, const std::vector<sparse_vector_type>& A,
matrix<T>& u, matrix<T,Unr,Unc,MM,L>& u,
matrix<T,0,1>& w, matrix<T,Wnr,Wnc,MM,L>& w,
matrix<T>& v, matrix<T,Vnr,Vnc,MM,L>& v,
unsigned long l, unsigned long l,
unsigned long q = 1 unsigned long q = 1
) )
...@@ -1022,14 +1048,14 @@ convergence: ...@@ -1022,14 +1048,14 @@ convergence:
<< "\n\t A.size(): " << A.size() << "\n\t A.size(): " << A.size()
); );
matrix<T> Q; matrix<T,0,0,MM,L> Q;
find_matrix_range(A, k, Q, q); find_matrix_range(A, k, Q, q);
// Compute trans(B) = trans(Q)*A. The reason we store B transposed // Compute trans(B) = trans(Q)*A. The reason we store B transposed
// is so that when we take its SVD later using svd3() it doesn't consume // is so that when we take its SVD later using svd3() it doesn't consume
// a whole lot of RAM. That is, we make sure the square matrix coming out // a whole lot of RAM. That is, we make sure the square matrix coming out
// of svd3() has size lxl rather than the potentially much larger nxn. // of svd3() has size lxl rather than the potentially much larger nxn.
matrix<T> B(n,k); matrix<T,0,0,MM,L> B(n,k);
B = 0; B = 0;
for (unsigned long m = 0; m < A.size(); ++m) for (unsigned long m = 0; m < A.size(); ++m)
{ {
......
...@@ -141,7 +141,7 @@ namespace dlib ...@@ -141,7 +141,7 @@ namespace dlib
void svd_fast ( void svd_fast (
const matrix<T>& A, const matrix<T>& A,
matrix<T>& u, matrix<T>& u,
matrix<T,0,1>& w, matrix<T>& w,
matrix<T>& v, matrix<T>& v,
unsigned long l, unsigned long l,
unsigned long q = 1 unsigned long q = 1
...@@ -191,7 +191,7 @@ namespace dlib ...@@ -191,7 +191,7 @@ namespace dlib
void svd_fast ( void svd_fast (
const std::vector<sparse_vector_type>& A, const std::vector<sparse_vector_type>& A,
matrix<T>& u, matrix<T>& u,
matrix<T,0,1>& w, matrix<T>& w,
matrix<T>& v, matrix<T>& v,
unsigned long l, unsigned long l,
unsigned long q = 1 unsigned long q = 1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment