Commit dd0dacb8 authored by Davis King's avatar Davis King

Gave pinv() an optional tolerance option.

parent 27136609
...@@ -1290,6 +1290,25 @@ convergence: ...@@ -1290,6 +1290,25 @@ convergence:
return matrix_diag_op<op>(op(reciprocal(diag(m)))); return matrix_diag_op<op>(op(reciprocal(diag(m))));
} }
// ----------------------------------------------------------------------------------------
template <
typename EXP
>
const matrix_diag_op<op_diag_inv<EXP> > pinv (
const matrix_diag_exp<EXP>& m,
double tol
)
{
DLIB_ASSERT(tol >= 0,
"\tconst matrix_exp::type pinv(const matrix_exp& m)"
<< "\n\t tol can't be negative"
<< "\n\t tol: "<<tol
);
typedef op_diag_inv<EXP> op;
return matrix_diag_op<op>(op(reciprocal(round_zeros(diag(m),tol))));
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename EXP> template <typename EXP>
...@@ -1519,7 +1538,8 @@ convergence: ...@@ -1519,7 +1538,8 @@ convergence:
typename EXP typename EXP
> >
const matrix<typename EXP::type,EXP::NC,EXP::NR,typename EXP::mem_manager_type> pinv_helper ( const matrix<typename EXP::type,EXP::NC,EXP::NR,typename EXP::mem_manager_type> pinv_helper (
const matrix_exp<EXP>& m const matrix_exp<EXP>& m,
double tol
) )
/*! /*!
ensures ensures
...@@ -1541,8 +1561,8 @@ convergence: ...@@ -1541,8 +1561,8 @@ convergence:
const double machine_eps = std::numeric_limits<typename EXP::type>::epsilon(); const double machine_eps = std::numeric_limits<typename EXP::type>::epsilon();
// compute a reasonable epsilon below which we round to zero before doing the // compute a reasonable epsilon below which we round to zero before doing the
// reciprocal // reciprocal. Unless a non-zero tol is given then we just use tol.
const double eps = machine_eps*std::max(m.nr(),m.nc())*max(w); const double eps = (tol!=0) ? tol : machine_eps*std::max(m.nr(),m.nc())*max(w);
// now compute the pseudoinverse // now compute the pseudoinverse
return tmp(scale_columns(v,reciprocal(round_zeros(w,eps))))*trans(u); return tmp(scale_columns(v,reciprocal(round_zeros(w,eps))))*trans(u);
...@@ -1552,15 +1572,21 @@ convergence: ...@@ -1552,15 +1572,21 @@ convergence:
typename EXP typename EXP
> >
const matrix<typename EXP::type,EXP::NC,EXP::NR,typename EXP::mem_manager_type> pinv ( const matrix<typename EXP::type,EXP::NC,EXP::NR,typename EXP::mem_manager_type> pinv (
const matrix_exp<EXP>& m const matrix_exp<EXP>& m,
double tol = 0
) )
{ {
DLIB_ASSERT(tol >= 0,
"\tconst matrix_exp::type pinv(const matrix_exp& m)"
<< "\n\t tol can't be negative"
<< "\n\t tol: "<<tol
);
// if m has more columns then rows then it is more efficient to // if m has more columns then rows then it is more efficient to
// compute the pseudo-inverse of its transpose (given the way I'm doing it below). // compute the pseudo-inverse of its transpose (given the way I'm doing it below).
if (m.nc() > m.nr()) if (m.nc() > m.nr())
return trans(pinv_helper(trans(m))); return trans(pinv_helper(trans(m),tol));
else else
return pinv_helper(m); return pinv_helper(m,tol);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -31,12 +31,20 @@ namespace dlib ...@@ -31,12 +31,20 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
const matrix pinv ( const matrix pinv (
const matrix_exp& m const matrix_exp& m,
double tol = 0
); );
/*! /*!
requires
- tol >= 0
ensures ensures
- returns the Moore-Penrose pseudoinverse of m. - returns the Moore-Penrose pseudoinverse of m.
- The returned matrix has m.nc() rows and m.nr() columns. - The returned matrix has m.nc() rows and m.nr() columns.
- if (tol == 0) then
- singular values less than max(m.nr(),m.nc()) times the machine epsilon
times the largest singular value are ignored.
- else
- singular values less than tol are ignored.
!*/ !*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -67,6 +67,40 @@ namespace ...@@ -67,6 +67,40 @@ namespace
DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix<double,5>()))); DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix<double,5>())));
DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m)))); DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m))));
DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m)))); DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m))));
mi = pinv(m,1e-12);
DLIB_TEST(mi.nr() == m.nc());
DLIB_TEST(mi.nc() == m.nr());
DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,5>())));
DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix<double,5>())));
DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m))));
DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m))));
m = diagm(diag(m));
mi = pinv(diagm(diag(m)),1e-12);
DLIB_TEST(mi.nr() == m.nc());
DLIB_TEST(mi.nc() == m.nr());
DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,5>())));
DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix<double,5>())));
DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m))));
DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m))));
mi = pinv(m,0);
DLIB_TEST(mi.nr() == m.nc());
DLIB_TEST(mi.nc() == m.nr());
DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,5>())));
DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix<double,5>())));
DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m))));
DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m))));
m = diagm(diag(m));
mi = pinv(diagm(diag(m)),0);
DLIB_TEST(mi.nr() == m.nc());
DLIB_TEST(mi.nc() == m.nr());
DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix<double,5>())));
DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix<double,5>())));
DLIB_TEST((equal(round_zeros(mi*m,0.000001) , identity_matrix(m))));
DLIB_TEST((equal(round_zeros(m*mi,0.000001) , identity_matrix(m))));
} }
{ {
matrix<double,5,0,MM> m(5,5); matrix<double,5,0,MM> m(5,5);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment