Commit 7a479f80 authored by Davis King's avatar Davis King

Added a bunch of overloads to catch operations on diagonal matrices

and use more efficient code paths for them.  For example, inv(diagm(d))
turns into diagm(reciprocal(d)).

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403915
parent d031d605
......@@ -183,6 +183,23 @@ namespace dlib
is_matrix<T>::value == 1 if T is a matrix type else 0
*/
// ----------------------------------------------------------------------------------------
template <
typename EXP
>
class matrix_diag_exp : public matrix_exp<EXP>
{
/*!
This is a matrix expression type used to represent diagonal matrices.
That is, square matrices with all off diagonal elements equal to 0.
!*/
protected:
matrix_diag_exp() {}
matrix_diag_exp(const matrix_diag_exp& item ):matrix_exp<EXP>(item) {}
};
// ----------------------------------------------------------------------------------------
}
......
......@@ -971,6 +971,64 @@ convergence:
const matrix_exp<EXP>& m
) { return inv_helper<EXP,matrix_exp<EXP>::NR>::inv(m); }
// ----------------------------------------------------------------------------------------
template <typename M>
struct op_diag_inv
{
template <typename EXP>
op_diag_inv( const matrix_exp<EXP>& m_) : m(m_){}
const static long cost = 1;
const static long NR = (M::NC&&M::NR)? (tmax<M::NR,M::NC>::value) : (0);
const static long NC = NR;
typedef typename M::type type;
typedef const type const_ret_type;
typedef typename M::mem_manager_type mem_manager_type;
typedef typename M::layout_type layout_type;
// hold the matrix by value
const matrix<type,NR,1,mem_manager_type,layout_type> m;
const_ret_type apply ( long r, long c) const
{
if (r==c)
return m(r);
else
return 0;
}
long nr () const { return m.size(); }
long nc () const { return m.size(); }
template <typename U> bool aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const { return m.aliases(item); }
};
template <
typename EXP
>
const matrix_diag_op<op_diag_inv<EXP> > inv (
const matrix_diag_exp<EXP>& m
)
{
typedef op_diag_inv<EXP> op;
return matrix_diag_op<op>(op(reciprocal(diag(m))));
}
template <
typename EXP
>
const matrix_diag_op<op_diag_inv<EXP> > pinv (
const matrix_diag_exp<EXP>& m
)
{
typedef op_diag_inv<EXP> op;
return matrix_diag_op<op>(op(reciprocal(diag(m))));
}
// ----------------------------------------------------------------------------------------
template <typename EXP>
......
......@@ -84,6 +84,84 @@ namespace dlib
) const { return op.nc(); }
const OP op;
};
// ----------------------------------------------------------------------------------------
template <typename OP >
class matrix_diag_op;
template < typename OP >
struct matrix_traits<matrix_diag_op<OP> >
{
typedef typename OP::type type;
typedef typename OP::const_ret_type const_ret_type;
typedef typename OP::mem_manager_type mem_manager_type;
typedef typename OP::layout_type layout_type;
const static long NR = OP::NR;
const static long NC = OP::NC;
const static long cost = OP::cost;
};
template <
typename OP
>
class matrix_diag_op : public matrix_diag_exp<matrix_diag_op<OP> >
{
/*!
WHAT THIS OBJECT REPRESENTS
The matrix_diag_op is simply a tool for reducing the amount of boilerplate
you need to write when creating matrix expressions.
!*/
public:
typedef typename matrix_traits<matrix_diag_op>::type type;
typedef typename matrix_traits<matrix_diag_op>::const_ret_type const_ret_type;
typedef typename matrix_traits<matrix_diag_op>::mem_manager_type mem_manager_type;
typedef typename matrix_traits<matrix_diag_op>::layout_type layout_type;
const static long NR = matrix_traits<matrix_diag_op>::NR;
const static long NC = matrix_traits<matrix_diag_op>::NC;
const static long cost = matrix_traits<matrix_diag_op>::cost;
private:
// This constructor exists simply for the purpose of causing a compile time error if
// someone tries to create an instance of this object with the wrong kind of object.
template <typename T1>
matrix_diag_op (T1);
public:
matrix_diag_op (
const OP& op_
) :
op(op_)
{}
const_ret_type operator() (
long r,
long c
) const { return op.apply(r,c); }
const_ret_type operator() ( long i ) const
{ return matrix_exp<matrix_diag_op>::operator()(i); }
template <typename U>
bool aliases (
const matrix_exp<U>& item
) const { return op.aliases(item); }
template <typename U>
bool destructively_aliases (
const matrix_exp<U>& item
) const { return op.destructively_aliases(item); }
long nr (
) const { return op.nr(); }
long nc (
) const { return op.nc(); }
const OP op;
};
......
This diff is collapsed.
......@@ -369,6 +369,141 @@ namespace
m = trans(m)+1;
DLIB_TEST(m == m2);
}
{
matrix<double> d(3,1), di(3,1);
matrix<double> m(3,3);
m = 1,2,3,
4,5,6,
7,8,9;
d = 1,2,3;
di = 1, 1/2.0, 1/3.0;
DLIB_TEST(inv(diagm(d)) == diagm(di));
DLIB_TEST(pinv(diagm(d)) == diagm(di));
DLIB_TEST(inv(diagm(d))*m == tmp(diagm(di))*m);
DLIB_TEST(m*inv(diagm(d)) == m*tmp(diagm(di)));
DLIB_TEST(inv(diagm(d)) + m == tmp(diagm(di)) + m);
DLIB_TEST(m + inv(diagm(d)) == tmp(diagm(di)) + m);
DLIB_TEST((m + identity_matrix<double>(3) == m + tmp(identity_matrix<double>(3))));
DLIB_TEST((m + identity_matrix<double,3>() == m + tmp(identity_matrix<double,3>())));
DLIB_TEST((m + 2*identity_matrix<double>(3) == m + 2*tmp(identity_matrix<double>(3))));
DLIB_TEST((m + 2*identity_matrix<double,3>() == m + 2*tmp(identity_matrix<double,3>())));
DLIB_TEST((m + identity_matrix<double>(3)*2 == m + 2*tmp(identity_matrix<double>(3))));
DLIB_TEST((m + identity_matrix<double,3>()*2 == m + 2*tmp(identity_matrix<double,3>())));
DLIB_TEST((identity_matrix<double>(3) + m == m + tmp(identity_matrix<double>(3))));
DLIB_TEST((identity_matrix<double,3>() + m == m + tmp(identity_matrix<double,3>())));
DLIB_TEST((2*identity_matrix<double>(3) + m == m + 2*tmp(identity_matrix<double>(3))));
DLIB_TEST((2*identity_matrix<double,3>() + m == m + 2*tmp(identity_matrix<double,3>())));
}
{
matrix<double,3,1> d(3,1), di(3,1);
matrix<double,3,3> m(3,3);
m = 1,2,3,
4,5,6,
7,8,9;
d = 1,2,3;
di = 1, 1/2.0, 1/3.0;
DLIB_TEST(inv(diagm(d)) == diagm(di));
DLIB_TEST(inv(diagm(d)) == diagm(di));
DLIB_TEST(inv(diagm(d))*m == tmp(diagm(di))*m);
DLIB_TEST(m*inv(diagm(d)) == m*tmp(diagm(di)));
DLIB_TEST(inv(diagm(d)) + m == tmp(diagm(di)) + m);
DLIB_TEST(m + inv(diagm(d)) == tmp(diagm(di)) + m);
DLIB_TEST((m + identity_matrix<double>(3) == m + tmp(identity_matrix<double>(3))));
DLIB_TEST((m + identity_matrix<double,3>() == m + tmp(identity_matrix<double,3>())));
DLIB_TEST((m + 2*identity_matrix<double>(3) == m + 2*tmp(identity_matrix<double>(3))));
DLIB_TEST((m + 2*identity_matrix<double,3>() == m + 2*tmp(identity_matrix<double,3>())));
DLIB_TEST((m + identity_matrix<double>(3)*2 == m + 2*tmp(identity_matrix<double>(3))));
DLIB_TEST((m + identity_matrix<double,3>()*2 == m + 2*tmp(identity_matrix<double,3>())));
DLIB_TEST((identity_matrix<double>(3) + m == m + tmp(identity_matrix<double>(3))));
DLIB_TEST((identity_matrix<double,3>() + m == m + tmp(identity_matrix<double,3>())));
DLIB_TEST((2*identity_matrix<double>(3) + m == m + 2*tmp(identity_matrix<double>(3))));
DLIB_TEST((2*identity_matrix<double,3>() + m == m + 2*tmp(identity_matrix<double,3>())));
}
{
matrix<double,1,3> d(1,3), di(1,3);
matrix<double,3,3> m(3,3);
m = 1,2,3,
4,5,6,
7,8,9;
d = 1,2,3;
di = 1, 1/2.0, 1/3.0;
DLIB_TEST(inv(diagm(d)) == diagm(di));
DLIB_TEST(inv(diagm(d)) == diagm(di));
DLIB_TEST(inv(diagm(d))*m == tmp(diagm(di))*m);
DLIB_TEST(m*inv(diagm(d)) == m*tmp(diagm(di)));
DLIB_TEST(inv(diagm(d)) + m == tmp(diagm(di)) + m);
DLIB_TEST(m + inv(diagm(d)) == tmp(diagm(di)) + m);
DLIB_TEST((m + identity_matrix<double>(3) == m + tmp(identity_matrix<double>(3))));
DLIB_TEST((m + identity_matrix<double,3>() == m + tmp(identity_matrix<double,3>())));
DLIB_TEST((m + 2*identity_matrix<double>(3) == m + 2*tmp(identity_matrix<double>(3))));
DLIB_TEST((m + 2*identity_matrix<double,3>() == m + 2*tmp(identity_matrix<double,3>())));
DLIB_TEST((m + identity_matrix<double>(3)*2 == m + 2*tmp(identity_matrix<double>(3))));
DLIB_TEST((m + identity_matrix<double,3>()*2 == m + 2*tmp(identity_matrix<double,3>())));
DLIB_TEST((identity_matrix<double>(3) + m == m + tmp(identity_matrix<double>(3))));
DLIB_TEST((identity_matrix<double,3>() + m == m + tmp(identity_matrix<double,3>())));
DLIB_TEST((2*identity_matrix<double>(3) + m == m + 2*tmp(identity_matrix<double>(3))));
DLIB_TEST((2*identity_matrix<double,3>() + m == m + 2*tmp(identity_matrix<double,3>())));
}
{
matrix<double,1,0> d(1,3), di(1,3);
matrix<double,0,3> m(3,3);
m = 1,2,3,
4,5,6,
7,8,9;
d = 1,2,3;
di = 1, 1/2.0, 1/3.0;
DLIB_TEST(inv(diagm(d)) == diagm(di));
DLIB_TEST(inv(diagm(d)) == diagm(di));
DLIB_TEST(inv(diagm(d))*m == tmp(diagm(di))*m);
DLIB_TEST(m*inv(diagm(d)) == m*tmp(diagm(di)));
DLIB_TEST(inv(diagm(d)) + m == tmp(diagm(di)) + m);
DLIB_TEST(m + inv(diagm(d)) == tmp(diagm(di)) + m);
DLIB_TEST((m + identity_matrix<double>(3) == m + tmp(identity_matrix<double>(3))));
DLIB_TEST((m + identity_matrix<double,3>() == m + tmp(identity_matrix<double,3>())));
DLIB_TEST((m + 2*identity_matrix<double>(3) == m + 2*tmp(identity_matrix<double>(3))));
DLIB_TEST((m + 2*identity_matrix<double,3>() == m + 2*tmp(identity_matrix<double,3>())));
DLIB_TEST((m + identity_matrix<double>(3)*2 == m + 2*tmp(identity_matrix<double>(3))));
DLIB_TEST((m + identity_matrix<double,3>()*2 == m + 2*tmp(identity_matrix<double,3>())));
DLIB_TEST((identity_matrix<double>(3) + m == m + tmp(identity_matrix<double>(3))));
DLIB_TEST((identity_matrix<double,3>() + m == m + tmp(identity_matrix<double,3>())));
DLIB_TEST((2*identity_matrix<double>(3) + m == m + 2*tmp(identity_matrix<double>(3))));
DLIB_TEST((2*identity_matrix<double,3>() + m == m + 2*tmp(identity_matrix<double,3>())));
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment