Commit e9df8898 authored by Davis King's avatar Davis King

- Added a missing operator*() for diagonal by diagonal matrix multiplication.

   Without it you would get an error about multiplication being ambiguous
   in this case.
 - Added an overload to catch expressions of the form diag_matrix*regular_matrix*diag_matrix
   and turn them into a form which is slightly more numerically stable in some cases.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403932
parent 5faadcbf
...@@ -531,6 +531,20 @@ namespace dlib ...@@ -531,6 +531,20 @@ namespace dlib
typedef op_trans<M> op; typedef op_trans<M> op;
return matrix_op<op>(op(m.ref())); return matrix_op<op>(op(m.ref()));
} }
// ----------------------------------------------------------------------------------------
// don't to anything at all for diagonal matrices
template <
typename M
>
const matrix_diag_exp<M>& trans (
const matrix_diag_exp<M>& m
)
{
return m;
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// I introduced this struct because it avoids an inane compiler warning from gcc // I introduced this struct because it avoids an inane compiler warning from gcc
...@@ -1020,6 +1034,51 @@ namespace dlib ...@@ -1020,6 +1034,51 @@ namespace dlib
return matrix_diag_op<op>(op(m.ref())); return matrix_diag_op<op>(op(m.ref()));
} }
// ----------------------------------------------------------------------------------------
template <typename M1, typename M2>
struct op_diagm_mult : basic_op_mm<M1,M2>
{
op_diagm_mult( const M1& m1_, const M2& m2_) : basic_op_mm<M1,M2>(m1_,m2_){}
typedef typename M1::type type;
typedef const type const_ret_type;
const static long cost = M1::cost + M2::cost + 1;
const_ret_type apply ( long r, long c) const
{
if (r == c)
return this->m1(r,c)*this->m2(r,c);
else
return 0;
}
};
template <
typename EXP1,
typename EXP2
>
inline const matrix_diag_op<op_diagm_mult<EXP1,EXP2> > operator* (
const matrix_diag_exp<EXP1>& a,
const matrix_diag_exp<EXP2>& b
)
{
COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type, typename EXP2::type>::value));
COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0);
COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NC == 0 || EXP2::NC == 0);
DLIB_ASSERT(a.nr() == b.nr() &&
a.nc() == b.nc(),
"\tconst matrix_exp operator(const matrix_diag_exp& a, const matrix_diag_exp& b)"
<< "\n\tYou can only multiply diagonal matrices together if they are the same size"
<< "\n\ta.nr(): " << a.nr()
<< "\n\ta.nc(): " << a.nc()
<< "\n\tb.nr(): " << b.nr()
<< "\n\tb.nc(): " << b.nc()
);
typedef op_diagm_mult<EXP1,EXP2> op;
return matrix_diag_op<op>(op(a.ref(),b.ref()));
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename M> template <typename M>
...@@ -2897,21 +2956,6 @@ namespace dlib ...@@ -2897,21 +2956,6 @@ namespace dlib
return matrix_op<op>(op(m.ref(),v.ref())); return matrix_op<op>(op(m.ref(),v.ref()));
} }
// ----------------------------------------------------------------------------------------
// turn expressions of the form mat*diagm(v) into scale_columns(mat, v)
template <
typename EXP1,
typename EXP2
>
const matrix_op<op_scale_columns<EXP1,EXP2> > operator* (
const matrix_exp<EXP1>& m,
const matrix_exp<matrix_diag_op<op_diagm<EXP2> > >& v
)
{
return scale_columns(m,v.ref().op.m);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename M1, typename M2> template <typename M1, typename M2>
...@@ -3031,21 +3075,6 @@ namespace dlib ...@@ -3031,21 +3075,6 @@ namespace dlib
return matrix_op<op>(op(m.ref(),v.ref())); return matrix_op<op>(op(m.ref(),v.ref()));
} }
// ----------------------------------------------------------------------------------------
// turn expressions of the form diagm(v)*mat into scale_rows(mat, v)
template <
typename EXP1,
typename EXP2
>
const matrix_op<op_scale_rows<EXP1,EXP2> > operator* (
const matrix_exp<matrix_diag_op<op_diagm<EXP2> > >& v,
const matrix_exp<EXP1>& m
)
{
return scale_rows(m,v.ref().op.m);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename M1, typename M2> template <typename M1, typename M2>
...@@ -3105,6 +3134,113 @@ namespace dlib ...@@ -3105,6 +3134,113 @@ namespace dlib
return matrix_op<op>(op(m.ref(),d.ref())); return matrix_op<op>(op(m.ref(),d.ref()));
} }
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
/*
The idea here is to catch expressions of the form d*M*d where d is diagonal and M
is some square matrix and turn them into something equivalent to
pointwise_multiply(diag(d)*trans(diag(d)), M).
The reason for this is that doing it this way is more numerically stable. In particular,
doing 2 matrix multiplies as suggested by d*M*d could result in an asymmetric matrix even
if M is symmetric to begin with.
*/
template <typename M1, typename M2, typename M3>
struct op_diag_m_diag
{
// This operator represents M1*M2*M3 where M1 and M3 are diagonal
op_diag_m_diag(const M1& m1_, const M2& m2_, const M3& m3_) : m1(m1_), m2(m2_), m3(m3_) {}
const M1& m1;
const M2& m2;
const M3& m3;
const static long cost = M1::cost + M2::cost + M3::cost + 1;
typedef typename M2::type type;
typedef const typename M2::type const_ret_type;
typedef typename M2::mem_manager_type mem_manager_type;
typedef typename M2::layout_type layout_type;
const static long NR = M2::NR;
const static long NC = M2::NC;
const_ret_type apply ( long r, long c) const { return (m1(r,r)*m3(c,c))*m2(r,c); }
long nr () const { return m2.nr(); }
long nc () const { return m2.nc(); }
template <typename U> bool aliases ( const matrix_exp<U>& item) const
{ return m1.aliases(item) || m2.aliases(item) || m3.aliases(item) ; }
template <typename U> bool destructively_aliases ( const matrix_exp<U>& item) const
{ return m2.destructively_aliases(item) || m1.aliases(item) || m3.aliases(item) ; }
};
// catch d*(M*d) = EXP1*EXP2*EXP3
template <
typename EXP1,
typename EXP2,
typename EXP3
>
const matrix_op<op_diag_m_diag<EXP1,EXP2,EXP3> > operator* (
const matrix_diag_exp<EXP1>& d,
const matrix_exp<matrix_op<op_scale_columns_diag<EXP2,EXP3> > >& m
)
{
// Both arguments to this function must contain the same type of element
COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
// figure out the compile time known length of d
const long v_len = ((EXP1::NR)*(EXP1::NC) == 0)? 0 : (tmax<EXP1::NR,EXP1::NC>::value);
// the length of d must match the number of rows in m
COMPILE_TIME_ASSERT(EXP2::NR == v_len || EXP2::NR == 0 || v_len == 0);
DLIB_ASSERT(d.nc() == m.nr(),
"\tconst matrix_exp operator*(d, m)"
<< "\n\tmatrix dimensions don't match"
<< "\n\td.nr(): " << d.nr()
<< "\n\td.nc(): " << d.nc()
<< "\n\tm.nr(): " << m.nr()
<< "\n\tm.nc(): " << m.nc()
);
typedef op_diag_m_diag<EXP1,EXP2,EXP3> op;
return matrix_op<op>(op(d.ref(), m.ref().op.m1, m.ref().op.m2));
}
// catch (d*M)*d = EXP1*EXP2*EXP3
template <
typename EXP1,
typename EXP2,
typename EXP3
>
const matrix_op<op_diag_m_diag<EXP1,EXP2,EXP3> > operator* (
const matrix_exp<matrix_op<op_scale_rows_diag<EXP2,EXP1> > >& m,
const matrix_diag_exp<EXP3>& d
)
{
// Both arguments to this function must contain the same type of element
COMPILE_TIME_ASSERT((is_same_type<typename EXP3::type,typename EXP2::type>::value == true));
// figure out the compile time known length of d
const long v_len = ((EXP3::NR)*(EXP3::NC) == 0)? 0 : (tmax<EXP3::NR,EXP3::NC>::value);
// the length of d must match the number of columns in m
COMPILE_TIME_ASSERT(EXP2::NC == v_len || EXP2::NC == 0 || v_len == 0);
DLIB_ASSERT(m.nc() == d.nr(),
"\tconst matrix_exp operator*(m, d)"
<< "\n\tmatrix dimensions don't match"
<< "\n\tm.nr(): " << m.nr()
<< "\n\tm.nc(): " << m.nc()
<< "\n\td.nr(): " << d.nr()
<< "\n\td.nc(): " << d.nc()
);
typedef op_diag_m_diag<EXP1,EXP2,EXP3> op;
return matrix_op<op>(op(m.ref().op.m2, m.ref().op.m1, d.ref()));
}
// ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
struct sort_columns_sort_helper struct sort_columns_sort_helper
......
...@@ -504,6 +504,94 @@ namespace ...@@ -504,6 +504,94 @@ namespace
DLIB_TEST((2*identity_matrix<double>(3) + m == m + 2*tmp(identity_matrix<double>(3)))); DLIB_TEST((2*identity_matrix<double>(3) + m == m + 2*tmp(identity_matrix<double>(3))));
DLIB_TEST((2*identity_matrix<double,3>() + m == m + 2*tmp(identity_matrix<double,3>()))); DLIB_TEST((2*identity_matrix<double,3>() + m == m + 2*tmp(identity_matrix<double,3>())));
} }
{
matrix<double,3,1> d1, d2;
d1 = 1,2,3;
d2 = 2,3,4;
matrix<double,3,3> ans;
ans = 2, 0, 0,
0, 6, 0,
0, 0, 12;
DLIB_TEST(ans == diagm(d1)*diagm(d2));
}
dlib::rand::float_1a rnd;
for (int i = 0; i < 1; ++i)
{
matrix<double> d1 = randm(4,1,rnd);
matrix<double,5,1> d2 = randm(5,1,rnd);
matrix<double,4,5> m = randm(4,5,rnd);
DLIB_TEST(pointwise_multiply(d1*trans(d2), m) == diagm(d1)*m*diagm(d2));
DLIB_TEST(pointwise_multiply(d1*trans(d2), m) == diagm(d1)*(m*diagm(d2)));
DLIB_TEST(pointwise_multiply(d1*trans(d2), m) == (diagm(d1)*m)*diagm(d2));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) == inv(diagm(d1))*m*inv(diagm(d2)));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) == inv(diagm(d1))*(m*inv(diagm(d2))));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) == (inv(diagm(d1))*m)*inv(diagm(d2)));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans((d2)), m) == inv(diagm(d1))*m*(diagm(d2)));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans((d2)), m) == inv(diagm(d1))*(m*(diagm(d2))));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans((d2)), m) == (inv(diagm(d1))*m)*(diagm(d2)));
DLIB_TEST(pointwise_multiply((d1)*trans(reciprocal(d2)), m) == (diagm(d1))*m*inv(diagm(d2)));
DLIB_TEST(pointwise_multiply((d1)*trans(reciprocal(d2)), m) == (diagm(d1))*(m*inv(diagm(d2))));
DLIB_TEST(pointwise_multiply((d1)*trans(reciprocal(d2)), m) == ((diagm(d1))*m)*inv(diagm(d2)));
}
for (int i = 0; i < 1; ++i)
{
matrix<double,4,1> d1 = randm(4,1,rnd);
matrix<double,5,1> d2 = randm(5,1,rnd);
matrix<double,4,5> m = randm(4,5,rnd);
DLIB_TEST(pointwise_multiply(d1*trans(d2), m) == diagm(d1)*m*diagm(d2));
DLIB_TEST(pointwise_multiply(d1*trans(d2), m) == diagm(d1)*(m*diagm(d2)));
DLIB_TEST(pointwise_multiply(d1*trans(d2), m) == (diagm(d1)*m)*diagm(d2));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) == inv(diagm(d1))*m*inv(diagm(d2)));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) == inv(diagm(d1))*(m*inv(diagm(d2))));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) == (inv(diagm(d1))*m)*inv(diagm(d2)));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans((d2)), m) == inv(diagm(d1))*m*(diagm(d2)));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans((d2)), m) == inv(diagm(d1))*(m*(diagm(d2))));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans((d2)), m) == (inv(diagm(d1))*m)*(diagm(d2)));
DLIB_TEST(pointwise_multiply((d1)*trans(reciprocal(d2)), m) == (diagm(d1))*m*inv(diagm(d2)));
DLIB_TEST(pointwise_multiply((d1)*trans(reciprocal(d2)), m) == (diagm(d1))*(m*inv(diagm(d2))));
DLIB_TEST(pointwise_multiply((d1)*trans(reciprocal(d2)), m) == ((diagm(d1))*m)*inv(diagm(d2)));
}
for (int i = 0; i < 1; ++i)
{
matrix<double,4,1> d1 = randm(4,1,rnd);
matrix<double,5,1> d2 = randm(5,1,rnd);
matrix<double,0,0> m = randm(4,5,rnd);
DLIB_TEST(pointwise_multiply(d1*trans(d2), m) == diagm(d1)*m*diagm(d2));
DLIB_TEST(pointwise_multiply(d1*trans(d2), m) == diagm(d1)*(m*diagm(d2)));
DLIB_TEST(pointwise_multiply(d1*trans(d2), m) == (diagm(d1)*m)*diagm(d2));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) == inv(diagm(d1))*m*inv(diagm(d2)));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) == inv(diagm(d1))*(m*inv(diagm(d2))));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans(reciprocal(d2)), m) == (inv(diagm(d1))*m)*inv(diagm(d2)));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans((d2)), m) == inv(diagm(d1))*m*(diagm(d2)));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans((d2)), m) == inv(diagm(d1))*(m*(diagm(d2))));
DLIB_TEST(pointwise_multiply(reciprocal(d1)*trans((d2)), m) == (inv(diagm(d1))*m)*(diagm(d2)));
DLIB_TEST(pointwise_multiply((d1)*trans(reciprocal(d2)), m) == (diagm(d1))*m*inv(diagm(d2)));
DLIB_TEST(pointwise_multiply((d1)*trans(reciprocal(d2)), m) == (diagm(d1))*(m*inv(diagm(d2))));
DLIB_TEST(pointwise_multiply((d1)*trans(reciprocal(d2)), m) == ((diagm(d1))*m)*inv(diagm(d2)));
}
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment