Commit 887f979b authored by Davis King's avatar Davis King

Added stuff to the BLAS bindings so that they can now bind

to statements with matrix sub expressions as the destination
matrix.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402775
parent bfa4b3f2
...@@ -25,6 +25,22 @@ namespace dlib ...@@ -25,6 +25,22 @@ namespace dlib
namespace blas_bindings namespace blas_bindings
{ {
// ------------------------------------------------------------------------------------
template <typename T>
void zero_matrix (
T& m
)
{
for (long r = 0; r < m.nr(); ++r)
{
for (long c = 0; c < m.nc(); ++c)
{
m(r,c) = 0;
}
}
}
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
// This template struct is used to tell us if a matrix expression contains a matrix multiply. // This template struct is used to tell us if a matrix expression contains a matrix multiply.
...@@ -187,7 +203,7 @@ namespace dlib ...@@ -187,7 +203,7 @@ namespace dlib
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template < template <
typename T, long NR, long NC, typename MM, typename L, typename dest_exp,
typename src_exp, typename src_exp,
typename enabled = void typename enabled = void
> >
...@@ -198,7 +214,7 @@ namespace dlib ...@@ -198,7 +214,7 @@ namespace dlib
// let the default matrix assignment happen. // let the default matrix assignment happen.
template <typename EXP> template <typename EXP>
static void assign ( static void assign (
matrix<T,NR,NC,MM,L>& dest, dest_exp& dest,
const EXP& src, const EXP& src,
typename src_exp::type alpha, typename src_exp::type alpha,
bool add_to bool add_to
...@@ -212,7 +228,7 @@ namespace dlib ...@@ -212,7 +228,7 @@ namespace dlib
// than the above default function would. // than the above default function would.
template <typename EXP1, typename EXP2> template <typename EXP1, typename EXP2>
static void assign ( static void assign (
matrix<T,NR,NC,MM,L>& dest, dest_exp& dest,
const matrix_multiply_exp<EXP1,EXP2>& src, const matrix_multiply_exp<EXP1,EXP2>& src,
typename src_exp::type alpha, typename src_exp::type alpha,
bool add_to bool add_to
...@@ -229,7 +245,7 @@ namespace dlib ...@@ -229,7 +245,7 @@ namespace dlib
} }
else else
{ {
set_all_elements(dest,0); zero_matrix(dest);
default_matrix_multiply(dest, src.lhs, src.rhs); default_matrix_multiply(dest, src.lhs, src.rhs);
} }
} }
...@@ -237,16 +253,16 @@ namespace dlib ...@@ -237,16 +253,16 @@ namespace dlib
{ {
if (add_to) if (add_to)
{ {
matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc()); typename dest_exp::matrix_type temp(dest.nr(),dest.nc());
set_all_elements(temp,0); zero_matrix(temp);
default_matrix_multiply(temp, src.lhs, src.rhs); default_matrix_multiply(temp, src.lhs, src.rhs);
dest += alpha*temp; matrix_assign_default(dest,temp, alpha,true);
} }
else else
{ {
set_all_elements(dest,0); zero_matrix(dest);
default_matrix_multiply(dest, src.lhs, src.rhs); default_matrix_multiply(dest, src.lhs, src.rhs);
dest = alpha*dest; matrix_assign_default(dest,dest, alpha, false);
} }
} }
} }
...@@ -254,19 +270,21 @@ namespace dlib ...@@ -254,19 +270,21 @@ namespace dlib
// This is a macro to help us add overloads for the matrix_assign_blas_helper template. // This is a macro to help us add overloads for the matrix_assign_blas_helper template.
// Using this macro it is easy to add overloads for arbitrary matrix expressions. // Using this macro it is easy to add overloads for arbitrary matrix expressions.
#define DLIB_ADD_BLAS_BINDING(src_expression) \ #define DLIB_ADD_BLAS_BINDING(src_expression) \
template <typename T, typename L> struct BOOST_JOIN(blas,__LINE__) \ template <typename T, typename L> struct BOOST_JOIN(blas,__LINE__) \
{ const static bool value = sizeof(yes_type) == sizeof(test<T,L>(src_expression)); }; \ { const static bool value = sizeof(yes_type) == sizeof(test<T,L>(src_expression)); }; \
template < typename T, long NR, long NC, typename MM, typename L, typename src_exp >\ \
struct matrix_assign_blas_helper<T,NR,NC,MM,L, src_exp, \ template < typename dest_exp, typename src_exp > \
typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp,L> >::type > { \ struct matrix_assign_blas_helper<dest_exp, src_exp, \
static void assign ( \ typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp,typename dest_exp::layout_type> >::type > { \
matrix<T,NR,NC,MM,L>& dest, \ static void assign ( \
const src_exp& src, \ dest_exp& dest, \
typename src_exp::type alpha, \ const src_exp& src, \
bool add_to \ typename src_exp::type alpha, \
) { \ bool add_to \
const bool is_row_major_order = is_same_type<L,row_major_layout>::value; ) { \
typedef typename dest_exp::type T; \
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
#define DLIB_END_BLAS_BINDING }}; #define DLIB_END_BLAS_BINDING }};
...@@ -277,11 +295,11 @@ namespace dlib ...@@ -277,11 +295,11 @@ namespace dlib
// ------------------- Forward Declarations ------------------- // ------------------- Forward Declarations -------------------
template < template <
typename T, long NR, long NC, typename MM, typename L, typename dest_exp,
typename src_exp typename src_exp
> >
void matrix_assign_blas_proxy ( void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest, dest_exp& dest,
const src_exp& src, const src_exp& src,
typename src_exp::type alpha, typename src_exp::type alpha,
bool add_to bool add_to
...@@ -294,11 +312,11 @@ namespace dlib ...@@ -294,11 +312,11 @@ namespace dlib
!*/ !*/
template < template <
typename T, long NR, long NC, typename MM, typename L, typename dest_exp,
typename src_exp, typename src_exp2 typename src_exp, typename src_exp2
> >
void matrix_assign_blas_proxy ( void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest, dest_exp& dest,
const matrix_add_exp<src_exp, src_exp2>& src, const matrix_add_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha, typename src_exp::type alpha,
bool add_to bool add_to
...@@ -311,11 +329,11 @@ namespace dlib ...@@ -311,11 +329,11 @@ namespace dlib
!*/ !*/
template < template <
typename T, long NR, long NC, typename MM, typename L, typename dest_exp,
typename src_exp, bool Sb typename src_exp, bool Sb
> >
void matrix_assign_blas_proxy ( void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest, dest_exp& dest,
const matrix_mul_scal_exp<src_exp,Sb>& src, const matrix_mul_scal_exp<src_exp,Sb>& src,
typename src_exp::type alpha, typename src_exp::type alpha,
bool add_to bool add_to
...@@ -328,11 +346,11 @@ namespace dlib ...@@ -328,11 +346,11 @@ namespace dlib
!*/ !*/
template < template <
typename T, long NR, long NC, typename MM, typename L, typename dest_exp,
typename src_exp, typename src_exp2 typename src_exp, typename src_exp2
> >
void matrix_assign_blas_proxy ( void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest, dest_exp& dest,
const matrix_subtract_exp<src_exp, src_exp2>& src, const matrix_subtract_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha, typename src_exp::type alpha,
bool add_to bool add_to
...@@ -370,6 +388,22 @@ namespace dlib ...@@ -370,6 +388,22 @@ namespace dlib
This is an important case to catch because it is the expression used This is an important case to catch because it is the expression used
to represent the += matrix operator. to represent the += matrix operator.
!*/ !*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<src_exp, matrix<T,NR,NC,MM,L> >& src
);
/*!
This function catches the expressions of the form:
M = exp + M;
and converts them into the appropriate matrix_assign_blas() call.
This is an important case to catch because it is the expression used
to represent the += matrix operator.
!*/
template < template <
typename T, long NR, long NC, typename MM, typename L, typename T, long NR, long NC, typename MM, typename L,
...@@ -395,27 +429,27 @@ namespace dlib ...@@ -395,27 +429,27 @@ namespace dlib
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template < template <
typename T, long NR, long NC, typename MM, typename L, typename dest_exp,
typename src_exp typename src_exp
> >
void matrix_assign_blas_proxy ( void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest, dest_exp& dest,
const src_exp& src, const src_exp& src,
typename src_exp::type alpha, typename src_exp::type alpha,
bool add_to bool add_to
) )
{ {
matrix_assign_blas_helper<T,NR,NC,MM,L,src_exp>::assign(dest,src,alpha,add_to); matrix_assign_blas_helper<dest_exp,src_exp>::assign(dest,src,alpha,add_to);
} }
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template < template <
typename T, long NR, long NC, typename MM, typename L, typename dest_exp,
typename src_exp, typename src_exp2 typename src_exp, typename src_exp2
> >
void matrix_assign_blas_proxy ( void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest, dest_exp& dest,
const matrix_add_exp<src_exp, src_exp2>& src, const matrix_add_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha, typename src_exp::type alpha,
bool add_to bool add_to
...@@ -435,11 +469,11 @@ namespace dlib ...@@ -435,11 +469,11 @@ namespace dlib
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template < template <
typename T, long NR, long NC, typename MM, typename L, typename dest_exp,
typename src_exp, bool Sb typename src_exp, bool Sb
> >
void matrix_assign_blas_proxy ( void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest, dest_exp& dest,
const matrix_mul_scal_exp<src_exp,Sb>& src, const matrix_mul_scal_exp<src_exp,Sb>& src,
typename src_exp::type alpha, typename src_exp::type alpha,
bool add_to bool add_to
...@@ -451,11 +485,11 @@ namespace dlib ...@@ -451,11 +485,11 @@ namespace dlib
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template < template <
typename T, long NR, long NC, typename MM, typename L, typename dest_exp,
typename src_exp, typename src_exp2 typename src_exp, typename src_exp2
> >
void matrix_assign_blas_proxy ( void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest, dest_exp& dest,
const matrix_subtract_exp<src_exp, src_exp2>& src, const matrix_subtract_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha, typename src_exp::type alpha,
bool add_to bool add_to
...@@ -500,6 +534,75 @@ namespace dlib ...@@ -500,6 +534,75 @@ namespace dlib
} }
} }
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
assignable_sub_matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
)
{
if (src.aliases(dest.m))
{
matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc());
matrix_assign_blas_proxy(temp,src,1,false);
matrix_assign_default(dest,temp);
}
else
{
matrix_assign_blas_proxy(dest,src,1,false);
}
}
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
assignable_row_matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
)
{
if (src.aliases(dest.m))
{
matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc());
matrix_assign_blas_proxy(temp,src,1,false);
matrix_assign_default(dest,temp);
}
else
{
matrix_assign_blas_proxy(dest,src,1,false);
}
}
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
assignable_col_matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
)
{
if (src.aliases(dest.m))
{
matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc());
matrix_assign_blas_proxy(temp,src,1,false);
matrix_assign_default(dest,temp);
}
else
{
matrix_assign_blas_proxy(dest,src,1,false);
}
}
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template < template <
...@@ -613,6 +716,63 @@ namespace dlib ...@@ -613,6 +716,63 @@ namespace dlib
blas_bindings::matrix_assign_blas(dest,src); blas_bindings::matrix_assign_blas(dest,src);
} }
// ----------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
inline typename enable_if_c<(is_same_type<T,float>::value ||
is_same_type<T,double>::value ||
is_same_type<T,std::complex<float> >::value ||
is_same_type<T,std::complex<double> >::value) &&
blas_bindings::has_matrix_multiply<src_exp>::value
>::type matrix_assign_big (
assignable_sub_matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
)
{
blas_bindings::matrix_assign_blas(dest,src);
}
// ----------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
inline typename enable_if_c<(is_same_type<T,float>::value ||
is_same_type<T,double>::value ||
is_same_type<T,std::complex<float> >::value ||
is_same_type<T,std::complex<double> >::value) &&
blas_bindings::has_matrix_multiply<src_exp>::value
>::type matrix_assign_big (
assignable_row_matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
)
{
blas_bindings::matrix_assign_blas(dest,src);
}
// ----------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
inline typename enable_if_c<(is_same_type<T,float>::value ||
is_same_type<T,double>::value ||
is_same_type<T,std::complex<float> >::value ||
is_same_type<T,std::complex<double> >::value) &&
blas_bindings::has_matrix_multiply<src_exp>::value
>::type matrix_assign_big (
assignable_col_matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
)
{
blas_bindings::matrix_assign_blas(dest,src);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -213,6 +213,11 @@ namespace dlib ...@@ -213,6 +213,11 @@ namespace dlib
template <typename T, long NR, long NC, typename MM> template <typename T, long NR, long NC, typename MM>
int get_ld (const matrix_sub_exp<matrix<T,NR,NC,MM,column_major_layout> >& m) { return m.m.nr(); } int get_ld (const matrix_sub_exp<matrix<T,NR,NC,MM,column_major_layout> >& m) { return m.m.nr(); }
template <typename T, long NR, long NC, typename MM>
int get_ld (const assignable_sub_matrix<T,NR,NC,MM,row_major_layout>& m) { return m.m.nc(); }
template <typename T, long NR, long NC, typename MM>
int get_ld (const assignable_sub_matrix<T,NR,NC,MM,column_major_layout>& m) { return m.m.nr(); }
// -------- // --------
...@@ -243,6 +248,31 @@ namespace dlib ...@@ -243,6 +248,31 @@ namespace dlib
return m.m.nr(); return m.m.nr();
} }
template <typename T, long NR, long NC, typename MM>
int get_inc(const assignable_row_matrix<T,NR,NC,MM,row_major_layout>& m)
{
return 1;
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const assignable_row_matrix<T,NR,NC,MM,column_major_layout>& m)
{
return m.m.nr();
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const assignable_col_matrix<T,NR,NC,MM,row_major_layout>& m)
{
return m.m.nc();
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const assignable_col_matrix<T,NR,NC,MM,column_major_layout>& m)
{
return 1;
}
// -------- // --------
template <typename T, long NR, long NC, typename MM, typename L> template <typename T, long NR, long NC, typename MM, typename L>
...@@ -260,6 +290,16 @@ namespace dlib ...@@ -260,6 +290,16 @@ namespace dlib
template <typename T, long NR, long NC, typename MM, typename L> template <typename T, long NR, long NC, typename MM, typename L>
const T* get_ptr (const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,L>,long,op_rowm>& m) { return &m.m(m.s,0); } const T* get_ptr (const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,L>,long,op_rowm>& m) { return &m.m(m.s,0); }
template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (assignable_col_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (assignable_row_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (assignable_sub_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -282,6 +322,7 @@ namespace dlib ...@@ -282,6 +322,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(m*m) DLIB_ADD_BLAS_BINDING(m*m)
{ {
//cout << "BLAS: m*m" << endl;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans; const CBLAS_TRANSPOSE TransB = CblasNoTrans;
......
...@@ -354,6 +354,10 @@ namespace dlib ...@@ -354,6 +354,10 @@ namespace dlib
class assignable_sub_matrix class assignable_sub_matrix
{ {
public: public:
typedef T type;
typedef l layout_type;
typedef matrix<T,NR,NC,mm,l> matrix_type;
assignable_sub_matrix( assignable_sub_matrix(
matrix<T,NR,NC,mm,l>& m_, matrix<T,NR,NC,mm,l>& m_,
const rectangle& rect_ const rectangle& rect_
...@@ -367,6 +371,14 @@ namespace dlib ...@@ -367,6 +371,14 @@ namespace dlib
return m(r+rect.top(),c+rect.left()); return m(r+rect.top(),c+rect.left());
} }
const T& operator() (
long r,
long c
) const
{
return m(r+rect.top(),c+rect.left());
}
long nr() const { return rect.height(); } long nr() const { return rect.height(); }
long nc() const { return rect.width(); } long nc() const { return rect.width(); }
...@@ -413,7 +425,6 @@ namespace dlib ...@@ -413,7 +425,6 @@ namespace dlib
return *this; return *this;
} }
private:
matrix<T,NR,NC,mm,l>& m; matrix<T,NR,NC,mm,l>& m;
const rectangle rect; const rectangle rect;
...@@ -615,6 +626,10 @@ namespace dlib ...@@ -615,6 +626,10 @@ namespace dlib
class assignable_col_matrix class assignable_col_matrix
{ {
public: public:
typedef T type;
typedef l layout_type;
typedef matrix<T,NR,NC,mm,l> matrix_type;
assignable_col_matrix( assignable_col_matrix(
matrix<T,NR,NC,mm,l>& m_, matrix<T,NR,NC,mm,l>& m_,
const long col_ const long col_
...@@ -628,6 +643,14 @@ namespace dlib ...@@ -628,6 +643,14 @@ namespace dlib
return m(r,col); return m(r,col);
} }
const T& operator() (
long r,
long c
) const
{
return m(r,col);
}
long nr() const { return m.nr(); } long nr() const { return m.nr(); }
long nc() const { return 1; } long nc() const { return 1; }
...@@ -670,7 +693,6 @@ namespace dlib ...@@ -670,7 +693,6 @@ namespace dlib
return *this; return *this;
} }
private:
matrix<T,NR,NC,mm,l>& m; matrix<T,NR,NC,mm,l>& m;
const long col; const long col;
...@@ -702,6 +724,10 @@ namespace dlib ...@@ -702,6 +724,10 @@ namespace dlib
class assignable_row_matrix class assignable_row_matrix
{ {
public: public:
typedef T type;
typedef l layout_type;
typedef matrix<T,NR,NC,mm,l> matrix_type;
assignable_row_matrix( assignable_row_matrix(
matrix<T,NR,NC,mm,l>& m_, matrix<T,NR,NC,mm,l>& m_,
const long row_ const long row_
...@@ -716,6 +742,14 @@ namespace dlib ...@@ -716,6 +742,14 @@ namespace dlib
return m(row,c); return m(row,c);
} }
const T& operator() (
long r,
long c
) const
{
return m(row,c);
}
long nr() const { return 1; } long nr() const { return 1; }
long nc() const { return m.nc(); } long nc() const { return m.nc(); }
...@@ -759,7 +793,6 @@ namespace dlib ...@@ -759,7 +793,6 @@ namespace dlib
return *this; return *this;
} }
private:
matrix<T,NR,NC,mm,l>& m; matrix<T,NR,NC,mm,l>& m;
const long row; const long row;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment