From b68878b412e329844e075986213d924d3298c503 Mon Sep 17 00:00:00 2001
From: Davis King <davis@dlib.net>
Date: Thu, 10 Jul 2008 01:04:45 +0000
Subject: [PATCH] Optimized the pinv function a little

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402396
---
 dlib/matrix/matrix_utilities.h | 23 ++++++++++++++++-------
 dlib/test/matrix.cpp           |  1 +
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/dlib/matrix/matrix_utilities.h b/dlib/matrix/matrix_utilities.h
index 378ceeef..e2155058 100644
--- a/dlib/matrix/matrix_utilities.h
+++ b/dlib/matrix/matrix_utilities.h
@@ -2412,19 +2412,28 @@ namespace dlib
     )
     { 
         typename matrix_exp<EXP>::matrix_type u;
-        matrix<typename EXP::type, EXP::NC, EXP::NC, typename EXP::mem_manager_type> w, v;
-        svd(m,u,w,v);
+        typedef typename EXP::mem_manager_type MM1;
+        matrix<typename EXP::type, EXP::NC, EXP::NC,MM1 > v;
+
+        typedef typename matrix_exp<EXP>::type T;
+
+        v.set_size(m.nc(),m.nc());
+
+        typedef typename matrix_exp<EXP>::type T;
+        u = m;
+
+        matrix<T,matrix_exp<EXP>::NC,1,MM1> w(m.nc(),1);
+        matrix<T,matrix_exp<EXP>::NC,1,MM1> rv1(m.nc(),1);
+
+        nric::svdcmp(u,w,v,rv1);
 
         const double machine_eps = std::numeric_limits<typename EXP::type>::epsilon();
         // compute a reasonable epsilon below which we round to zero before doing the
         // reciprocal
-        const double eps = machine_eps*std::max(m.nr(),m.nc())*max(diag(w));
-
-        // compute the reciprocal of the diagonal of w
-        matrix<typename EXP::type,EXP::NC,1, typename EXP::mem_manager_type> w_diag = reciprocal(round_zeros(diag(w),eps));
+        const double eps = machine_eps*std::max(m.nr(),m.nc())*max(w);
 
         // now compute the pseudoinverse
-        return tmp(scale_columns(v,w_diag))*trans(u);
+        return tmp(scale_columns(v,reciprocal(round_zeros(w,eps))))*trans(u);
     }
 
 // ----------------------------------------------------------------------------------------
diff --git a/dlib/test/matrix.cpp b/dlib/test/matrix.cpp
index 993e5ddc..ba1c2961 100644
--- a/dlib/test/matrix.cpp
+++ b/dlib/test/matrix.cpp
@@ -912,6 +912,7 @@ namespace
             matrix<double> mi = pinv(m ); 
             DLIB_CASSERT(mi.nr() == m.nc(),"");
             DLIB_CASSERT(mi.nc() == m.nr(),"");
+            DLIB_CASSERT((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,2>())),"");
         }
 
         {
-- 
2.18.0