From b68878b412e329844e075986213d924d3298c503 Mon Sep 17 00:00:00 2001 From: Davis King <davis@dlib.net> Date: Thu, 10 Jul 2008 01:04:45 +0000 Subject: [PATCH] Optimized the pinv function a little --HG-- extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402396 --- dlib/matrix/matrix_utilities.h | 23 ++++++++++++++++------- dlib/test/matrix.cpp | 1 + 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/dlib/matrix/matrix_utilities.h b/dlib/matrix/matrix_utilities.h index 378ceeef..e2155058 100644 --- a/dlib/matrix/matrix_utilities.h +++ b/dlib/matrix/matrix_utilities.h @@ -2412,19 +2412,28 @@ namespace dlib ) { typename matrix_exp<EXP>::matrix_type u; - matrix<typename EXP::type, EXP::NC, EXP::NC, typename EXP::mem_manager_type> w, v; - svd(m,u,w,v); + typedef typename EXP::mem_manager_type MM1; + matrix<typename EXP::type, EXP::NC, EXP::NC,MM1 > v; + + typedef typename matrix_exp<EXP>::type T; + + v.set_size(m.nc(),m.nc()); + + typedef typename matrix_exp<EXP>::type T; + u = m; + + matrix<T,matrix_exp<EXP>::NC,1,MM1> w(m.nc(),1); + matrix<T,matrix_exp<EXP>::NC,1,MM1> rv1(m.nc(),1); + + nric::svdcmp(u,w,v,rv1); const double machine_eps = std::numeric_limits<typename EXP::type>::epsilon(); // compute a reasonable epsilon below which we round to zero before doing the // reciprocal - const double eps = machine_eps*std::max(m.nr(),m.nc())*max(diag(w)); - - // compute the reciprocal of the diagonal of w - matrix<typename EXP::type,EXP::NC,1, typename EXP::mem_manager_type> w_diag = reciprocal(round_zeros(diag(w),eps)); + const double eps = machine_eps*std::max(m.nr(),m.nc())*max(w); // now compute the pseudoinverse - return tmp(scale_columns(v,w_diag))*trans(u); + return tmp(scale_columns(v,reciprocal(round_zeros(w,eps))))*trans(u); } // ---------------------------------------------------------------------------------------- diff --git a/dlib/test/matrix.cpp b/dlib/test/matrix.cpp index 993e5ddc..ba1c2961 100644 --- a/dlib/test/matrix.cpp +++ b/dlib/test/matrix.cpp @@ -912,6 +912,7 @@ namespace matrix<double> mi = pinv(m ); DLIB_CASSERT(mi.nr() == m.nc(),""); DLIB_CASSERT(mi.nc() == m.nr(),""); + DLIB_CASSERT((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,2>())),""); } { -- 2.18.0