Commit c596cebb authored by Davis King's avatar Davis King

Improved the BLAS binding system. It should now break expressions up into

their proper basic BLAS function calls.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402764
parent 6b5f9a8a
......@@ -24,6 +24,39 @@ namespace dlib
namespace blas_bindings
{
// ------------------------------------------------------------------------------------
// This template struct is used to tell us if a matrix expression contains a matrix multiply.
template <typename T>
struct has_matrix_multiply
{
const static bool value = false;
};
template <typename T, typename U>
struct has_matrix_multiply<matrix_multiply_exp<T,U> >
{ const static bool value = true; };
template <typename T, typename U>
struct has_matrix_multiply<matrix_add_exp<T,U> >
{ const static bool value = has_matrix_multiply<T>::value || has_matrix_multiply<U>::value; };
template <typename T, typename U>
struct has_matrix_multiply<matrix_subtract_exp<T,U> >
{ const static bool value = has_matrix_multiply<T>::value || has_matrix_multiply<U>::value; };
template <typename T, bool Tb>
struct has_matrix_multiply<matrix_mul_scal_exp<T,Tb> >
{ const static bool value = has_matrix_multiply<T>::value; };
template <typename T>
struct has_matrix_multiply<matrix_div_scal_exp<T> >
{ const static bool value = has_matrix_multiply<T>::value; };
template <typename T, typename OP>
struct has_matrix_multiply<matrix_unary_exp<T,OP> >
{ const static bool value = has_matrix_multiply<T>::value; };
// ------------------------------------------------------------------------------------
template <typename T, typename U>
......@@ -110,10 +143,12 @@ namespace dlib
template <typename EXP>
static void assign (
matrix<T,NR,NC,MM,L>& dest,
const EXP& src
const EXP& src,
typename src_exp::type alpha,
bool add_to
)
{
matrix_assign_default(dest,src);
matrix_assign_default(dest,src,alpha,add_to);
}
// If we know this is a matrix multiply then apply the
......@@ -122,99 +157,356 @@ namespace dlib
template <typename EXP1, typename EXP2>
static void assign (
matrix<T,NR,NC,MM,L>& dest,
const matrix_multiply_exp<EXP1,EXP2>& src
const matrix_multiply_exp<EXP1,EXP2>& src,
typename src_exp::type alpha,
bool add_to
)
{
set_all_elements(dest,0);
default_matrix_multiply(dest, src.lhs, src.rhs);
}
// At some point I need to improve the default (i.e. non BLAS) matrix
// multiplication algorithm...
template <typename EXP1, typename EXP2>
static void assign (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<matrix<T,NR,NC,MM,L>, matrix_multiply_exp<EXP1,EXP2> >& src
)
{
if (&dest == &src.lhs)
if (alpha == 1)
{
default_matrix_multiply(dest, src.rhs.lhs, src.rhs.rhs);
if (add_to)
{
default_matrix_multiply(dest, src.lhs, src.rhs);
}
else
{
set_all_elements(dest,0);
default_matrix_multiply(dest, src.lhs, src.rhs);
}
}
else
{
dest = src.lhs;
default_matrix_multiply(dest, src.rhs.lhs, src.rhs.rhs);
if (add_to)
{
matrix<T,NR,NC,MM,L> temp(dest);
default_matrix_multiply(temp, src.lhs, src.rhs);
dest = alpha*temp;
}
else
{
set_all_elements(dest,0);
default_matrix_multiply(dest, src.lhs, src.rhs);
dest = alpha*dest;
}
}
}
};
template <typename EXP1, typename EXP2>
static void assign (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<matrix<T,NR,NC,MM,L>, matrix_add_exp<EXP1,EXP2> >& src
)
// This is a macro to help us add overloads for the matrix_assign_blas_helper template.
// Using this macro it is easy to add overloads for arbitrary matrix expressions.
#define DLIB_ADD_BLAS_BINDING( dest_layout, src_expression) \
template <typename T> struct BOOST_JOIN(blas,__LINE__) \
{ const static bool value = sizeof(yes_type) == sizeof(test<T>(src_expression)); }; \
template < typename T, long NR, long NC, typename MM, typename src_exp > \
struct matrix_assign_blas_helper<T,NR,NC,MM,dest_layout, src_exp, \
typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp> >::type > { \
static void assign ( \
matrix<T,NR,NC,MM,dest_layout>& dest, \
const src_exp& src, \
typename src_exp::type alpha, \
bool add_to \
) {
#define DLIB_END_BLAS_BINDING }};
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------- Forward Declarations -------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const src_exp& src,
typename src_exp::type alpha,
bool add_to
);
/*!
requires
- src.aliases(dest) == false
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
);
/*!
requires
- src.aliases(dest) == false
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, bool Sb
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_mul_scal_exp<src_exp,Sb>& src,
typename src_exp::type alpha,
bool add_to
);
/*!
requires
- src.aliases(dest) == false
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_subtract_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
);
/*!
requires
- src.aliases(dest) == false
!*/
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
);
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
);
/*!
This function catches the expressions of the form:
M = M + exp;
and converts them into the appropriate matrix_assign_blas() call.
This is an important case to catch because it is the expression used
to represent the += matrix operator.
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const matrix_subtract_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
);
/*!
This function catches the expressions of the form:
M = M - exp;
and converts them into the appropriate matrix_assign_blas() call.
This is an important case to catch because it is the expression used
to represent the -= matrix operator.
!*/
// End of forward declarations for overloaded matrix_assign_blas functions
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const src_exp& src,
typename src_exp::type alpha,
bool add_to
)
{
matrix_assign_blas_helper<T,NR,NC,MM,L,src_exp>::assign(dest,src,alpha,add_to);
}
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
)
{
if (src_exp::cost > 9 || src_exp2::cost > 9)
{
if (EXP1::cost > 50 || EXP2::cost > 5)
{
matrix_assign(dest, src.lhs + src.rhs.lhs);
matrix_assign(dest, src.lhs + src.rhs.rhs);
}
else
{
matrix_assign_default(dest,src);
}
matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to);
matrix_assign_blas_proxy(dest, src.rhs, alpha, true);
}
else
{
matrix_assign_default(dest, src, alpha, add_to);
}
}
// ------------------------------------------------------------------------------------
template <typename EXP2>
static void assign (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<matrix<T,NR,NC,MM,L>,EXP2>& src
)
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, bool Sb
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_mul_scal_exp<src_exp,Sb>& src,
typename src_exp::type alpha,
bool add_to
)
{
matrix_assign_blas_proxy(dest, src.m, alpha*src.s, add_to);
}
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_subtract_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
)
{
if (src_exp::cost > 9 || src_exp2::cost > 9)
{
matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to);
matrix_assign_blas_proxy(dest, src.rhs, -alpha, true);
}
else
{
matrix_assign_default(dest, src, alpha, add_to);
}
}
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// Once we get into this function it means that we are dealing with a matrix of float,
// double, complex<float>, or complex<double> and the src_exp contains at least one
// matrix multiply.
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
)
{
if (src.aliases(dest))
{
matrix<T,NR,NC,MM,L> temp;
matrix_assign_blas_proxy(temp,src,1,false);
temp.swap(dest);
}
else
{
matrix_assign_blas_proxy(dest,src,1,false);
}
}
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
)
{
if (src_exp::cost > 5)
{
if (EXP2::cost > 50 && &dest != &src.lhs)
if (src.rhs.aliases(dest) == false)
{
dest = src.lhs;
matrix_assign(dest, dest + src.rhs);
if (&src.lhs != &dest)
{
dest = src.lhs;
}
matrix_assign_blas_proxy(dest, src.rhs, 1, true);
}
else
{
matrix_assign_default(dest,src);
matrix<T,NR,NC,MM,L> temp(src.lhs);
matrix_assign_blas_proxy(temp, src.rhs, 1, true);
temp.swap(dest);
}
}
else
{
matrix_assign_default(dest,src);
}
}
// ------------------------------------------------------------------------------------
template <typename EXP1, typename EXP2>
static void assign (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<EXP1,EXP2>& src
)
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const matrix_subtract_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
)
{
if (src_exp::cost > 5)
{
if (EXP1::cost > 50 || EXP2::cost > 50)
if (src.rhs.aliases(dest) == false)
{
matrix_assign(dest,src.lhs);
matrix_assign(dest, dest + src.rhs);
if (&src.lhs != &dest)
{
dest = src.lhs;
}
matrix_assign_blas_proxy(dest, src.rhs, -1, true);
}
else
{
matrix_assign_default(dest,src);
matrix<T,NR,NC,MM,L> temp(src.lhs);
matrix_assign_blas_proxy(temp, src.rhs, -1, true);
temp.swap(dest);
}
}
};
// This is a macro to help us add overloads for the matrix_assign_blas_helper template.
// Using this macro it is easy to add overloads for arbitrary matrix expressions.
#define DLIB_ADD_BLAS_BINDING( dest_layout, src_expression) \
template <typename T> struct BOOST_JOIN(blas,__LINE__) \
{ const static bool value = sizeof(yes_type) == sizeof(test<T>(src_expression)); }; \
template < typename T, long NR, long NC, typename MM, typename src_exp > \
struct matrix_assign_blas_helper<T,NR,NC,MM,dest_layout, src_exp, \
typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp> >::type > { \
static void assign ( \
matrix<T,NR,NC,MM,dest_layout>& dest, \
const src_exp& src \
) {
#define DLIB_END_BLAS_BINDING }};
else
{
matrix_assign_default(dest,src);
}
}
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
} // end of namespace blas_bindings
......@@ -227,13 +519,14 @@ namespace dlib
inline typename enable_if_c<(is_same_type<T,float>::value ||
is_same_type<T,double>::value ||
is_same_type<T,std::complex<float> >::value ||
is_same_type<T,std::complex<double> >::value)
is_same_type<T,std::complex<double> >::value) &&
blas_bindings::has_matrix_multiply<src_exp>::value
>::type matrix_assign_big (
matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
)
{
blas_bindings::matrix_assign_blas_helper<T,NR,NC,MM,L,src_exp>::assign(dest,src);
blas_bindings::matrix_assign_blas(dest,src);
}
// ----------------------------------------------------------------------------------------
......
......@@ -44,6 +44,12 @@ namespace dlib
EXP1& dest,
const EXP2& src
)
/*!
requires
- src.destructively_aliases(dest) == false
ensures
- #dest == src
!*/
{
for (long r = 0; r < src.nr(); ++r)
{
......@@ -54,6 +60,83 @@ namespace dlib
}
}
// ----------------------------------------------------------------------------------------
template <typename EXP1, typename EXP2>
inline static void matrix_assign_default (
EXP1& dest,
const EXP2& src,
typename EXP2::type alpha,
bool add_to
)
/*!
requires
- src.destructively_aliases(dest) == false
ensures
- if (add_to == false) then
- #dest == alpha*src
- else
- #dest == dest + alpha*src
!*/
{
if (add_to)
{
if (alpha == 1)
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) += src(r,c);
}
}
}
else if (alpha == -1)
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) -= src(r,c);
}
}
}
else
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) += alpha*src(r,c);
}
}
}
}
else
{
if (alpha == 1)
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) = src(r,c);
}
}
}
else
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) = alpha*src(r,c);
}
}
}
}
}
// ----------------------------------------------------------------------------------------
template <
......@@ -83,7 +166,6 @@ namespace dlib
- src.destructively_aliases(dest) == false
ensures
- #dest == src
- the part of dest outside the above sub matrix remains unchanged
!*/
{
// Call src.ref() here so that the derived type of the matrix_exp shows
......@@ -107,7 +189,6 @@ namespace dlib
- src.destructively_aliases(dest) == false
ensures
- #dest == src
- the part of dest outside the above sub matrix remains unchanged
!*/
{
matrix_assign_default(dest,src.ref());
......
......@@ -9,6 +9,8 @@
#include "cblas.h"
#endif
#include <iostream>
namespace dlib
{
......@@ -219,45 +221,18 @@ namespace dlib
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const T alpha = 1;
const T* A = &src.lhs(0,0);
const int lda = src.lhs.nc();
const T* B = &src.rhs(0,0);
const int ldb = src.rhs.nc();
const T beta = 0;
const T beta = add_to?1:0;
T* C = &dest(0,0);
const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout,rm + rm*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const T alpha = 1;
const T* A = &src.rhs.lhs(0,0);
const int lda = src.rhs.lhs.nc();
const T* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const T beta = 1;
T* C = &dest(0,0);
const int ldc = src.rhs.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*rm)
......@@ -268,94 +243,18 @@ namespace dlib
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const T alpha = 1;
const T* A = &src.lhs.m(0,0);
const int lda = src.lhs.m.nc();
const T* B = &src.rhs(0,0);
const int ldb = src.rhs.nc();
const T beta = 0;
const T beta = add_to?1:0;
T* C = &dest(0,0);
const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout, rm + s*trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc());
const int K = static_cast<int>(src.rhs.m.lhs.nc());
const T alpha = src.rhs.s;
const T* A = &src.rhs.m.lhs.m(0,0);
const int lda = src.rhs.m.lhs.m.nc();
const T* B = &src.rhs.m.rhs(0,0);
const int ldb = src.rhs.m.rhs.nc();
const T beta = 1;
T* C = &dest(0,0);
const int ldc = dest.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout, s*trans(rm)*rm)
{
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.m.nr());
const int N = static_cast<int>(src.m.nc());
const int K = static_cast<int>(src.m.lhs.nc());
const T alpha = src.s;
const T* A = &src.m.lhs.m(0,0);
const int lda = src.m.lhs.m.nc();
const T* B = &src.m.rhs(0,0);
const int ldb = src.m.rhs.nc();
const T beta = 0;
T* C = &dest(0,0);
const int ldc = dest.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout, rm + trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const T alpha = 1;
const T* A = &src.rhs.lhs.m(0,0);
const int lda = src.rhs.lhs.m.nc();
const T* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const T beta = 1;
T* C = &dest(0,0);
const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
#endif // DLIB_USE_BLAS
}
......
......@@ -23,24 +23,25 @@ namespace dlib
// ------------------------------------------------------------------------------------
template <
typename matrix_dest_type,
typename EXP1,
typename EXP2
>
typename enable_if_c<ma::matrix_is_vector<EXP1>::value == true && ma::matrix_is_vector<EXP2>::value == true>::type
default_matrix_multiply (
matrix_dest_type& dest,
const EXP1& lhs,
const EXP2& rhs
);
/*!
requires
- (lhs*rhs).destructively_aliases(dest) == false
- dest.nr() == (lhs*rhs).nr()
- dest.nc() == (lhs*rhs).nc()
ensures
- #dest == dest + lhs*rhs
/*! This file defines the default_matrix_multiply() function. It is a function
that conforms to the following definition:
template <
typename matrix_dest_type,
typename EXP1,
typename EXP2
>
void default_matrix_multiply (
matrix_dest_type& dest,
const EXP1& lhs,
const EXP2& rhs
);
requires
- (lhs*rhs).destructively_aliases(dest) == false
- dest.nr() == (lhs*rhs).nr()
- dest.nc() == (lhs*rhs).nc()
ensures
- #dest == dest + lhs*rhs
!*/
// ------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment