Commit 887f979b authored by Davis King's avatar Davis King

Added stuff to the BLAS bindings so that they can now bind

to statements with matrix sub expressions as the destination
matrix.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402775
parent bfa4b3f2
......@@ -25,6 +25,22 @@ namespace dlib
namespace blas_bindings
{
// ------------------------------------------------------------------------------------
template <typename T>
void zero_matrix (
T& m
)
{
for (long r = 0; r < m.nr(); ++r)
{
for (long c = 0; c < m.nc(); ++c)
{
m(r,c) = 0;
}
}
}
// ------------------------------------------------------------------------------------
// This template struct is used to tell us if a matrix expression contains a matrix multiply.
......@@ -187,7 +203,7 @@ namespace dlib
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename dest_exp,
typename src_exp,
typename enabled = void
>
......@@ -198,7 +214,7 @@ namespace dlib
// let the default matrix assignment happen.
template <typename EXP>
static void assign (
matrix<T,NR,NC,MM,L>& dest,
dest_exp& dest,
const EXP& src,
typename src_exp::type alpha,
bool add_to
......@@ -212,7 +228,7 @@ namespace dlib
// than the above default function would.
template <typename EXP1, typename EXP2>
static void assign (
matrix<T,NR,NC,MM,L>& dest,
dest_exp& dest,
const matrix_multiply_exp<EXP1,EXP2>& src,
typename src_exp::type alpha,
bool add_to
......@@ -229,7 +245,7 @@ namespace dlib
}
else
{
set_all_elements(dest,0);
zero_matrix(dest);
default_matrix_multiply(dest, src.lhs, src.rhs);
}
}
......@@ -237,16 +253,16 @@ namespace dlib
{
if (add_to)
{
matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc());
set_all_elements(temp,0);
typename dest_exp::matrix_type temp(dest.nr(),dest.nc());
zero_matrix(temp);
default_matrix_multiply(temp, src.lhs, src.rhs);
dest += alpha*temp;
matrix_assign_default(dest,temp, alpha,true);
}
else
{
set_all_elements(dest,0);
zero_matrix(dest);
default_matrix_multiply(dest, src.lhs, src.rhs);
dest = alpha*dest;
matrix_assign_default(dest,dest, alpha, false);
}
}
}
......@@ -254,19 +270,21 @@ namespace dlib
// This is a macro to help us add overloads for the matrix_assign_blas_helper template.
// Using this macro it is easy to add overloads for arbitrary matrix expressions.
#define DLIB_ADD_BLAS_BINDING(src_expression) \
template <typename T, typename L> struct BOOST_JOIN(blas,__LINE__) \
{ const static bool value = sizeof(yes_type) == sizeof(test<T,L>(src_expression)); }; \
template < typename T, long NR, long NC, typename MM, typename L, typename src_exp >\
struct matrix_assign_blas_helper<T,NR,NC,MM,L, src_exp, \
typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp,L> >::type > { \
static void assign ( \
matrix<T,NR,NC,MM,L>& dest, \
const src_exp& src, \
typename src_exp::type alpha, \
bool add_to \
) { \
const bool is_row_major_order = is_same_type<L,row_major_layout>::value;
#define DLIB_ADD_BLAS_BINDING(src_expression) \
template <typename T, typename L> struct BOOST_JOIN(blas,__LINE__) \
{ const static bool value = sizeof(yes_type) == sizeof(test<T,L>(src_expression)); }; \
\
template < typename dest_exp, typename src_exp > \
struct matrix_assign_blas_helper<dest_exp, src_exp, \
typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp,typename dest_exp::layout_type> >::type > { \
static void assign ( \
dest_exp& dest, \
const src_exp& src, \
typename src_exp::type alpha, \
bool add_to \
) { \
typedef typename dest_exp::type T; \
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
#define DLIB_END_BLAS_BINDING }};
......@@ -277,11 +295,11 @@ namespace dlib
// ------------------- Forward Declarations -------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename dest_exp,
typename src_exp
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
dest_exp& dest,
const src_exp& src,
typename src_exp::type alpha,
bool add_to
......@@ -294,11 +312,11 @@ namespace dlib
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename dest_exp,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
dest_exp& dest,
const matrix_add_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
......@@ -311,11 +329,11 @@ namespace dlib
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename dest_exp,
typename src_exp, bool Sb
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
dest_exp& dest,
const matrix_mul_scal_exp<src_exp,Sb>& src,
typename src_exp::type alpha,
bool add_to
......@@ -328,11 +346,11 @@ namespace dlib
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename dest_exp,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
dest_exp& dest,
const matrix_subtract_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
......@@ -370,6 +388,22 @@ namespace dlib
This is an important case to catch because it is the expression used
to represent the += matrix operator.
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<src_exp, matrix<T,NR,NC,MM,L> >& src
);
/*!
This function catches the expressions of the form:
M = exp + M;
and converts them into the appropriate matrix_assign_blas() call.
This is an important case to catch because it is the expression used
to represent the += matrix operator.
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
......@@ -395,27 +429,27 @@ namespace dlib
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename dest_exp,
typename src_exp
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
dest_exp& dest,
const src_exp& src,
typename src_exp::type alpha,
bool add_to
)
{
matrix_assign_blas_helper<T,NR,NC,MM,L,src_exp>::assign(dest,src,alpha,add_to);
matrix_assign_blas_helper<dest_exp,src_exp>::assign(dest,src,alpha,add_to);
}
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename dest_exp,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
dest_exp& dest,
const matrix_add_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
......@@ -435,11 +469,11 @@ namespace dlib
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename dest_exp,
typename src_exp, bool Sb
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
dest_exp& dest,
const matrix_mul_scal_exp<src_exp,Sb>& src,
typename src_exp::type alpha,
bool add_to
......@@ -451,11 +485,11 @@ namespace dlib
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename dest_exp,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
dest_exp& dest,
const matrix_subtract_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
......@@ -500,6 +534,75 @@ namespace dlib
}
}
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
assignable_sub_matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
)
{
if (src.aliases(dest.m))
{
matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc());
matrix_assign_blas_proxy(temp,src,1,false);
matrix_assign_default(dest,temp);
}
else
{
matrix_assign_blas_proxy(dest,src,1,false);
}
}
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
assignable_row_matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
)
{
if (src.aliases(dest.m))
{
matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc());
matrix_assign_blas_proxy(temp,src,1,false);
matrix_assign_default(dest,temp);
}
else
{
matrix_assign_blas_proxy(dest,src,1,false);
}
}
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
assignable_col_matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
)
{
if (src.aliases(dest.m))
{
matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc());
matrix_assign_blas_proxy(temp,src,1,false);
matrix_assign_default(dest,temp);
}
else
{
matrix_assign_blas_proxy(dest,src,1,false);
}
}
// ------------------------------------------------------------------------------------
template <
......@@ -613,6 +716,63 @@ namespace dlib
blas_bindings::matrix_assign_blas(dest,src);
}
// ----------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
inline typename enable_if_c<(is_same_type<T,float>::value ||
is_same_type<T,double>::value ||
is_same_type<T,std::complex<float> >::value ||
is_same_type<T,std::complex<double> >::value) &&
blas_bindings::has_matrix_multiply<src_exp>::value
>::type matrix_assign_big (
assignable_sub_matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
)
{
blas_bindings::matrix_assign_blas(dest,src);
}
// ----------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
inline typename enable_if_c<(is_same_type<T,float>::value ||
is_same_type<T,double>::value ||
is_same_type<T,std::complex<float> >::value ||
is_same_type<T,std::complex<double> >::value) &&
blas_bindings::has_matrix_multiply<src_exp>::value
>::type matrix_assign_big (
assignable_row_matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
)
{
blas_bindings::matrix_assign_blas(dest,src);
}
// ----------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
inline typename enable_if_c<(is_same_type<T,float>::value ||
is_same_type<T,double>::value ||
is_same_type<T,std::complex<float> >::value ||
is_same_type<T,std::complex<double> >::value) &&
blas_bindings::has_matrix_multiply<src_exp>::value
>::type matrix_assign_big (
assignable_col_matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
)
{
blas_bindings::matrix_assign_blas(dest,src);
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......
......@@ -213,6 +213,11 @@ namespace dlib
template <typename T, long NR, long NC, typename MM>
int get_ld (const matrix_sub_exp<matrix<T,NR,NC,MM,column_major_layout> >& m) { return m.m.nr(); }
template <typename T, long NR, long NC, typename MM>
int get_ld (const assignable_sub_matrix<T,NR,NC,MM,row_major_layout>& m) { return m.m.nc(); }
template <typename T, long NR, long NC, typename MM>
int get_ld (const assignable_sub_matrix<T,NR,NC,MM,column_major_layout>& m) { return m.m.nr(); }
// --------
......@@ -243,6 +248,31 @@ namespace dlib
return m.m.nr();
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const assignable_row_matrix<T,NR,NC,MM,row_major_layout>& m)
{
return 1;
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const assignable_row_matrix<T,NR,NC,MM,column_major_layout>& m)
{
return m.m.nr();
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const assignable_col_matrix<T,NR,NC,MM,row_major_layout>& m)
{
return m.m.nc();
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const assignable_col_matrix<T,NR,NC,MM,column_major_layout>& m)
{
return 1;
}
// --------
template <typename T, long NR, long NC, typename MM, typename L>
......@@ -260,6 +290,16 @@ namespace dlib
template <typename T, long NR, long NC, typename MM, typename L>
const T* get_ptr (const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,L>,long,op_rowm>& m) { return &m.m(m.s,0); }
template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (assignable_col_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (assignable_row_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (assignable_sub_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......@@ -282,6 +322,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(m*m)
{
//cout << "BLAS: m*m" << endl;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
......
......@@ -354,6 +354,10 @@ namespace dlib
class assignable_sub_matrix
{
public:
typedef T type;
typedef l layout_type;
typedef matrix<T,NR,NC,mm,l> matrix_type;
assignable_sub_matrix(
matrix<T,NR,NC,mm,l>& m_,
const rectangle& rect_
......@@ -367,6 +371,14 @@ namespace dlib
return m(r+rect.top(),c+rect.left());
}
const T& operator() (
long r,
long c
) const
{
return m(r+rect.top(),c+rect.left());
}
long nr() const { return rect.height(); }
long nc() const { return rect.width(); }
......@@ -413,7 +425,6 @@ namespace dlib
return *this;
}
private:
matrix<T,NR,NC,mm,l>& m;
const rectangle rect;
......@@ -615,6 +626,10 @@ namespace dlib
class assignable_col_matrix
{
public:
typedef T type;
typedef l layout_type;
typedef matrix<T,NR,NC,mm,l> matrix_type;
assignable_col_matrix(
matrix<T,NR,NC,mm,l>& m_,
const long col_
......@@ -628,6 +643,14 @@ namespace dlib
return m(r,col);
}
const T& operator() (
long r,
long c
) const
{
return m(r,col);
}
long nr() const { return m.nr(); }
long nc() const { return 1; }
......@@ -670,7 +693,6 @@ namespace dlib
return *this;
}
private:
matrix<T,NR,NC,mm,l>& m;
const long col;
......@@ -702,6 +724,10 @@ namespace dlib
class assignable_row_matrix
{
public:
typedef T type;
typedef l layout_type;
typedef matrix<T,NR,NC,mm,l> matrix_type;
assignable_row_matrix(
matrix<T,NR,NC,mm,l>& m_,
const long row_
......@@ -716,6 +742,14 @@ namespace dlib
return m(row,c);
}
const T& operator() (
long r,
long c
) const
{
return m(row,c);
}
long nr() const { return 1; }
long nc() const { return m.nc(); }
......@@ -759,7 +793,6 @@ namespace dlib
return *this;
}
private:
matrix<T,NR,NC,mm,l>& m;
const long row;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment