Commit 8f443707 authored by Davis King's avatar Davis King

Added a tensor_product() function for the matrix object.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402570
parent d12f667e
......@@ -3835,6 +3835,46 @@ namespace dlib
}
// ----------------------------------------------------------------------------------------
struct op_tensor_product
{
template <typename EXP1, typename EXP2>
struct op : public has_destructive_aliasing
{
const static long NR = EXP1::NR*EXP2::NR;
const static long NC = EXP1::NC*EXP2::NC;
typedef typename EXP1::type type;
typedef typename EXP1::mem_manager_type mem_manager_type;
template <typename M1, typename M2>
static type apply ( const M1& m1, const M2& m2 , long r, long c)
{
return m1(r/m2.nr(),c/m2.nc())*m2(r%m2.nr(),c%m2.nc());
}
template <typename M1, typename M2>
static long nr (const M1& m1, const M2& m2 ) { return m1.nr()*m2.nr(); }
template <typename M1, typename M2>
static long nc (const M1& m1, const M2& m2 ) { return m1.nc()*m2.nc(); }
};
};
template <
typename EXP1,
typename EXP2
>
inline const matrix_exp<matrix_binary_exp<EXP1,EXP2,op_tensor_product> > tensor_product (
const matrix_exp<EXP1>& a,
const matrix_exp<EXP2>& b
)
{
COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
typedef matrix_binary_exp<EXP1,EXP2,op_tensor_product> exp;
return matrix_exp<exp>(exp(a.ref(),b.ref()));
}
// ----------------------------------------------------------------------------------------
}
......
......@@ -574,6 +574,25 @@ namespace dlib
performs pointwise_multiply(pointwise_multiply(a,b),pointwise_multiply(c,d));
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp tensor_product (
const matrix_exp& a,
const matrix_exp& b
);
/*!
requires
- a and b both contain the same type of element
ensures
- returns a matrix R such that:
- R::type == the same type that was in a and b.
- R.nr() == a.nr()*b.nr()
- R.nc() == a.nc()*b.nc()
- for all valid r and c:
R(r,c) == a(r/b.nr(), c/b.nc()) * b(r%b.nr(), c%b.nc())
- I.e. R is the tensor product of matrix a with matrix b
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp scale_columns (
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment