Commit c596cebb authored by Davis King

Improved the BLAS binding system. It should now break expressions up into

their proper basic BLAS function calls.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402764
parent 6b5f9a8a
...@@ -24,6 +24,39 @@ namespace dlib ...@@ -24,6 +24,39 @@ namespace dlib
namespace blas_bindings namespace blas_bindings
{ {
// ------------------------------------------------------------------------------------
// This template struct is used to tell us if a matrix expression contains a matrix multiply.
template <typename T>
struct has_matrix_multiply
{
    // Primary template: an arbitrary type is assumed not to contain a matrix
    // multiply.  The specializations for the individual expression node types
    // override this answer.
    static const bool value = false;
};
// A matrix_multiply_exp node is itself a matrix multiply.
template <typename T, typename U>
struct has_matrix_multiply<matrix_multiply_exp<T,U> >
{ const static bool value = true; };
// A sum contains a matrix multiply if either of its operands does.
template <typename T, typename U>
struct has_matrix_multiply<matrix_add_exp<T,U> >
{ const static bool value = has_matrix_multiply<T>::value || has_matrix_multiply<U>::value; };
// A difference contains a matrix multiply if either of its operands does.
template <typename T, typename U>
struct has_matrix_multiply<matrix_subtract_exp<T,U> >
{ const static bool value = has_matrix_multiply<T>::value || has_matrix_multiply<U>::value; };
// Scalar multiplication: the answer is whatever it is for the wrapped expression T.
template <typename T, bool Tb>
struct has_matrix_multiply<matrix_mul_scal_exp<T,Tb> >
{ const static bool value = has_matrix_multiply<T>::value; };
// Scalar division: the answer is whatever it is for the wrapped expression T.
template <typename T>
struct has_matrix_multiply<matrix_div_scal_exp<T> >
{ const static bool value = has_matrix_multiply<T>::value; };
// Unary expression: the answer is whatever it is for the wrapped expression T
// (the operation OP is not inspected).
template <typename T, typename OP>
struct has_matrix_multiply<matrix_unary_exp<T,OP> >
{ const static bool value = has_matrix_multiply<T>::value; };
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template <typename T, typename U> template <typename T, typename U>
...@@ -110,10 +143,12 @@ namespace dlib ...@@ -110,10 +143,12 @@ namespace dlib
template <typename EXP> template <typename EXP>
static void assign ( static void assign (
matrix<T,NR,NC,MM,L>& dest, matrix<T,NR,NC,MM,L>& dest,
const EXP& src const EXP& src,
typename src_exp::type alpha,
bool add_to
) )
{ {
matrix_assign_default(dest,src); matrix_assign_default(dest,src,alpha,add_to);
} }
// If we know this is a matrix multiply then apply the // If we know this is a matrix multiply then apply the
...@@ -122,99 +157,356 @@ namespace dlib ...@@ -122,99 +157,356 @@ namespace dlib
template <typename EXP1, typename EXP2> template <typename EXP1, typename EXP2>
static void assign ( static void assign (
matrix<T,NR,NC,MM,L>& dest, matrix<T,NR,NC,MM,L>& dest,
const matrix_multiply_exp<EXP1,EXP2>& src const matrix_multiply_exp<EXP1,EXP2>& src,
typename src_exp::type alpha,
bool add_to
) )
{ {
set_all_elements(dest,0); // At some point I need to improve the default (i.e. non BLAS) matrix
default_matrix_multiply(dest, src.lhs, src.rhs); // multiplication algorithm...
}
template <typename EXP1, typename EXP2> if (alpha == 1)
static void assign (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<matrix<T,NR,NC,MM,L>, matrix_multiply_exp<EXP1,EXP2> >& src
)
{
if (&dest == &src.lhs)
{ {
default_matrix_multiply(dest, src.rhs.lhs, src.rhs.rhs); if (add_to)
{
default_matrix_multiply(dest, src.lhs, src.rhs);
}
else
{
set_all_elements(dest,0);
default_matrix_multiply(dest, src.lhs, src.rhs);
}
} }
else else
{ {
dest = src.lhs; if (add_to)
default_matrix_multiply(dest, src.rhs.lhs, src.rhs.rhs); {
// default_matrix_multiply() accumulates into its destination
// (#dest == dest + lhs*rhs), so the product has to be formed in a zeroed
// temporary and alpha times that product then added to dest.  Multiplying
// into a copy of dest and assigning alpha*temp would also scale the old
// contents of dest by alpha, yielding alpha*dest + alpha*lhs*rhs.
matrix<T,NR,NC,MM,L> temp(dest);   // copy only to obtain dest's dimensions
set_all_elements(temp,0);
default_matrix_multiply(temp, src.lhs, src.rhs);
matrix_assign_default(dest, temp, alpha, true);
}
else
{
set_all_elements(dest,0);
default_matrix_multiply(dest, src.lhs, src.rhs);
dest = alpha*dest;
}
} }
} }
};
template <typename EXP1, typename EXP2> // This is a macro to help us add overloads for the matrix_assign_blas_helper template.
static void assign ( // Using this macro it is easy to add overloads for arbitrary matrix expressions.
matrix<T,NR,NC,MM,L>& dest, #define DLIB_ADD_BLAS_BINDING( dest_layout, src_expression) \
const matrix_add_exp<matrix<T,NR,NC,MM,L>, matrix_add_exp<EXP1,EXP2> >& src template <typename T> struct BOOST_JOIN(blas,__LINE__) \
) { const static bool value = sizeof(yes_type) == sizeof(test<T>(src_expression)); }; \
template < typename T, long NR, long NC, typename MM, typename src_exp > \
struct matrix_assign_blas_helper<T,NR,NC,MM,dest_layout, src_exp, \
typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp> >::type > { \
static void assign ( \
matrix<T,NR,NC,MM,dest_layout>& dest, \
const src_exp& src, \
typename src_exp::type alpha, \
bool add_to \
) {
#define DLIB_END_BLAS_BINDING }};
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------- Forward Declarations -------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const src_exp& src,
typename src_exp::type alpha,
bool add_to
);
/*!
requires
- src.aliases(dest) == false
ensures
- if (add_to == false) then
- #dest == alpha*src
- else
- #dest == dest + alpha*src
- This is the catch-all overload.  It hands the expression to
matrix_assign_blas_helper<>::assign() without breaking it up any further.
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
);
/*!
requires
- src.aliases(dest) == false
ensures
- if (add_to == false) then
- #dest == alpha*src
- else
- #dest == dest + alpha*src
- Overload for sums.  When either operand has a cost greater than 9 the two
operands are assigned one after the other (the second one always being
added to dest); otherwise matrix_assign_default() evaluates the whole sum.
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, bool Sb
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_mul_scal_exp<src_exp,Sb>& src,
typename src_exp::type alpha,
bool add_to
);
/*!
requires
- src.aliases(dest) == false
ensures
- if (add_to == false) then
- #dest == alpha*src
- else
- #dest == dest + alpha*src
- Overload for scalar multiplication.  The scalar src.s is folded into alpha
and the assignment continues with the inner expression src.m.
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_subtract_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
);
/*!
requires
- src.aliases(dest) == false
ensures
- if (add_to == false) then
- #dest == alpha*src
- else
- #dest == dest + alpha*src
- Overload for differences.  When either operand has a cost greater than 9
the left operand is assigned with alpha and the right operand is then added
with -alpha; otherwise matrix_assign_default() evaluates the whole
difference.
!*/
// ------------------------------------------------------------------------------------
// General entry point: assigns src to dest.  If src aliases dest the
// expression is evaluated into a temporary that is then swapped into dest;
// otherwise it is evaluated directly into dest via matrix_assign_blas_proxy().
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
);
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
);
/*!
This function catches the expressions of the form:
M = M + exp;
and converts them into the appropriate matrix_assign_blas() call.
This is an important case to catch because it is the expression used
to represent the += matrix operator.

The conversion is only attempted when src_exp::cost > 5.  Cheaper
expressions are evaluated with matrix_assign_default().
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const matrix_subtract_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
);
/*!
This function catches the expressions of the form:
M = M - exp;
and converts them into the appropriate matrix_assign_blas() call.
This is an important case to catch because it is the expression used
to represent the -= matrix operator.

The conversion is only attempted when src_exp::cost > 5.  Cheaper
expressions are evaluated with matrix_assign_default().
!*/
// End of forward declarations for overloaded matrix_assign_blas functions
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
template <
    typename T, long NR, long NC, typename MM, typename L,
    typename src_exp
    >
void matrix_assign_blas_proxy (
    matrix<T,NR,NC,MM,L>& dest,
    const src_exp& src,
    typename src_exp::type alpha,
    bool add_to
)
/*!
    requires
        - src.aliases(dest) == false
    ensures
        - forwards all arguments unchanged to the matrix_assign_blas_helper
          specialization selected for this destination type and src_exp.
!*/
{
    // Catch-all case: the expression is not split any further, so let the
    // helper chosen for src_exp perform the assignment.
    typedef matrix_assign_blas_helper<T,NR,NC,MM,L,src_exp> helper_type;
    helper_type::assign(dest, src, alpha, add_to);
}
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
)
{
if (src_exp::cost > 9 || src_exp2::cost > 9)
{ {
if (EXP1::cost > 50 || EXP2::cost > 5) matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to);
{ matrix_assign_blas_proxy(dest, src.rhs, alpha, true);
matrix_assign(dest, src.lhs + src.rhs.lhs);
matrix_assign(dest, src.lhs + src.rhs.rhs);
}
else
{
matrix_assign_default(dest,src);
}
} }
else
{
matrix_assign_default(dest, src, alpha, add_to);
}
}
// ------------------------------------------------------------------------------------
template <typename EXP2> template <
static void assign ( typename T, long NR, long NC, typename MM, typename L,
matrix<T,NR,NC,MM,L>& dest, typename src_exp, bool Sb
const matrix_add_exp<matrix<T,NR,NC,MM,L>,EXP2>& src >
) void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_mul_scal_exp<src_exp,Sb>& src,
typename src_exp::type alpha,
bool add_to
)
{
matrix_assign_blas_proxy(dest, src.m, alpha*src.s, add_to);
}
// ------------------------------------------------------------------------------------
template <
    typename T, long NR, long NC, typename MM, typename L,
    typename src_exp, typename src_exp2
    >
void matrix_assign_blas_proxy (
    matrix<T,NR,NC,MM,L>& dest,
    const matrix_subtract_exp<src_exp, src_exp2>& src,
    typename src_exp::type alpha,
    bool add_to
)
/*!
    requires
        - src.aliases(dest) == false
    ensures
        - if (add_to == false) then
            - #dest == alpha*src
        - else
            - #dest == dest + alpha*src
!*/
{
    const bool split_operands = (src_exp::cost > 9 || src_exp2::cost > 9);

    if (!split_operands)
    {
        // Both operands are cheap: evaluate the whole difference element by element.
        matrix_assign_default(dest, src, alpha, add_to);
        return;
    }

    // alpha*(lhs - rhs) is applied as alpha*lhs followed by adding (-alpha)*rhs.
    // The first assignment honours add_to; the second must always accumulate.
    matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to);
    matrix_assign_blas_proxy(dest, src.rhs, -alpha, true);
}
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// Once we get into this function it means that we are dealing with a matrix of float,
// double, complex<float>, or complex<double> and the src_exp contains at least one
// matrix multiply.
template <
    typename T, long NR, long NC, typename MM, typename L,
    typename src_exp
    >
void matrix_assign_blas (
    matrix<T,NR,NC,MM,L>& dest,
    const src_exp& src
)
/*!
    ensures
        - #dest == src
        - if src aliases dest then src is evaluated into a temporary which is
          swapped into dest afterwards, so src's operands are never overwritten
          while they are still being read.
!*/
{
    if (src.aliases(dest))
    {
        // The temporary must already have the dimensions of the result: the
        // assignment routines write src.nr() x src.nc() elements into their
        // destination and never resize it.  A default constructed matrix is
        // empty when its dimensions are not fixed at compile time.
        // NOTE(review): relies on dlib's matrix(rows, cols) constructor - confirm
        // it is available in this revision of matrix.h.
        matrix<T,NR,NC,MM,L> temp(src.nr(), src.nc());
        matrix_assign_blas_proxy(temp,src,1,false);
        temp.swap(dest);
    }
    else
    {
        matrix_assign_blas_proxy(dest,src,1,false);
    }
}
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
)
{
if (src_exp::cost > 5)
{ {
if (EXP2::cost > 50 && &dest != &src.lhs) if (src.rhs.aliases(dest) == false)
{ {
dest = src.lhs; if (&src.lhs != &dest)
matrix_assign(dest, dest + src.rhs); {
dest = src.lhs;
}
matrix_assign_blas_proxy(dest, src.rhs, 1, true);
} }
else else
{ {
matrix_assign_default(dest,src); matrix<T,NR,NC,MM,L> temp(src.lhs);
matrix_assign_blas_proxy(temp, src.rhs, 1, true);
temp.swap(dest);
} }
} }
else
{
matrix_assign_default(dest,src);
}
}
// ------------------------------------------------------------------------------------
template <
template <typename EXP1, typename EXP2> typename T, long NR, long NC, typename MM, typename L,
static void assign ( typename src_exp
matrix<T,NR,NC,MM,L>& dest, >
const matrix_add_exp<EXP1,EXP2>& src void matrix_assign_blas (
) matrix<T,NR,NC,MM,L>& dest,
const matrix_subtract_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
)
{
if (src_exp::cost > 5)
{ {
if (EXP1::cost > 50 || EXP2::cost > 50) if (src.rhs.aliases(dest) == false)
{ {
matrix_assign(dest,src.lhs); if (&src.lhs != &dest)
matrix_assign(dest, dest + src.rhs); {
dest = src.lhs;
}
matrix_assign_blas_proxy(dest, src.rhs, -1, true);
} }
else else
{ {
matrix_assign_default(dest,src); matrix<T,NR,NC,MM,L> temp(src.lhs);
matrix_assign_blas_proxy(temp, src.rhs, -1, true);
temp.swap(dest);
} }
} }
}; else
{
// This is a macro to help us add overloads for the matrix_assign_blas_helper template. matrix_assign_default(dest,src);
// Using this macro it is easy to add overloads for arbitrary matrix expressions. }
#define DLIB_ADD_BLAS_BINDING( dest_layout, src_expression) \ }
template <typename T> struct BOOST_JOIN(blas,__LINE__) \
{ const static bool value = sizeof(yes_type) == sizeof(test<T>(src_expression)); }; \
template < typename T, long NR, long NC, typename MM, typename src_exp > \
struct matrix_assign_blas_helper<T,NR,NC,MM,dest_layout, src_exp, \
typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp> >::type > { \
static void assign ( \
matrix<T,NR,NC,MM,dest_layout>& dest, \
const src_exp& src \
) {
#define DLIB_END_BLAS_BINDING }};
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
} // end of namespace blas_bindings } // end of namespace blas_bindings
...@@ -227,13 +519,14 @@ namespace dlib ...@@ -227,13 +519,14 @@ namespace dlib
inline typename enable_if_c<(is_same_type<T,float>::value || inline typename enable_if_c<(is_same_type<T,float>::value ||
is_same_type<T,double>::value || is_same_type<T,double>::value ||
is_same_type<T,std::complex<float> >::value || is_same_type<T,std::complex<float> >::value ||
is_same_type<T,std::complex<double> >::value) is_same_type<T,std::complex<double> >::value) &&
blas_bindings::has_matrix_multiply<src_exp>::value
>::type matrix_assign_big ( >::type matrix_assign_big (
matrix<T,NR,NC,MM,L>& dest, matrix<T,NR,NC,MM,L>& dest,
const src_exp& src const src_exp& src
) )
{ {
blas_bindings::matrix_assign_blas_helper<T,NR,NC,MM,L,src_exp>::assign(dest,src); blas_bindings::matrix_assign_blas(dest,src);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -44,6 +44,12 @@ namespace dlib ...@@ -44,6 +44,12 @@ namespace dlib
EXP1& dest, EXP1& dest,
const EXP2& src const EXP2& src
) )
/*!
requires
- src.destructively_aliases(dest) == false
ensures
- #dest == src
!*/
{ {
for (long r = 0; r < src.nr(); ++r) for (long r = 0; r < src.nr(); ++r)
{ {
...@@ -54,6 +60,83 @@ namespace dlib ...@@ -54,6 +60,83 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
template <typename EXP1, typename EXP2>
inline static void matrix_assign_default (
    EXP1& dest,
    const EXP2& src,
    typename EXP2::type alpha,
    bool add_to
)
/*!
    requires
        - src.destructively_aliases(dest) == false
    ensures
        - if (add_to == false) then
            - #dest == alpha*src
        - else
            - #dest == dest + alpha*src
!*/
{
    typedef typename EXP2::type scalar_type;

    // The special-case constants are built in the element type.  Comparing
    // alpha against the int literals 1 and -1 does not compile when the
    // element type is std::complex<float> or std::complex<double>, because
    // std::complex has no operator== taking an int.
    const scalar_type one       = static_cast<scalar_type>(1);
    const scalar_type minus_one = static_cast<scalar_type>(-1);

    if (add_to)
    {
        if (alpha == one)
        {
            // dest += src, skipping the multiplication
            for (long r = 0; r < src.nr(); ++r)
            {
                for (long c = 0; c < src.nc(); ++c)
                {
                    dest(r,c) += src(r,c);
                }
            }
        }
        else if (alpha == minus_one)
        {
            // dest -= src, skipping the multiplication
            for (long r = 0; r < src.nr(); ++r)
            {
                for (long c = 0; c < src.nc(); ++c)
                {
                    dest(r,c) -= src(r,c);
                }
            }
        }
        else
        {
            for (long r = 0; r < src.nr(); ++r)
            {
                for (long c = 0; c < src.nc(); ++c)
                {
                    dest(r,c) += alpha*src(r,c);
                }
            }
        }
    }
    else
    {
        if (alpha == one)
        {
            // plain copy, skipping the multiplication
            for (long r = 0; r < src.nr(); ++r)
            {
                for (long c = 0; c < src.nc(); ++c)
                {
                    dest(r,c) = src(r,c);
                }
            }
        }
        else
        {
            for (long r = 0; r < src.nr(); ++r)
            {
                for (long c = 0; c < src.nc(); ++c)
                {
                    dest(r,c) = alpha*src(r,c);
                }
            }
        }
    }
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -83,7 +166,6 @@ namespace dlib ...@@ -83,7 +166,6 @@ namespace dlib
- src.destructively_aliases(dest) == false - src.destructively_aliases(dest) == false
ensures ensures
- #dest == src - #dest == src
- the part of dest outside the above sub matrix remains unchanged
!*/ !*/
{ {
// Call src.ref() here so that the derived type of the matrix_exp shows // Call src.ref() here so that the derived type of the matrix_exp shows
...@@ -107,7 +189,6 @@ namespace dlib ...@@ -107,7 +189,6 @@ namespace dlib
- src.destructively_aliases(dest) == false - src.destructively_aliases(dest) == false
ensures ensures
- #dest == src - #dest == src
- the part of dest outside the above sub matrix remains unchanged
!*/ !*/
{ {
matrix_assign_default(dest,src.ref()); matrix_assign_default(dest,src.ref());
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include "cblas.h" #include "cblas.h"
#endif #endif
#include <iostream>
namespace dlib namespace dlib
{ {
...@@ -219,45 +221,18 @@ namespace dlib ...@@ -219,45 +221,18 @@ namespace dlib
const int M = static_cast<int>(src.nr()); const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc()); const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc()); const int K = static_cast<int>(src.lhs.nc());
const T alpha = 1;
const T* A = &src.lhs(0,0); const T* A = &src.lhs(0,0);
const int lda = src.lhs.nc(); const int lda = src.lhs.nc();
const T* B = &src.rhs(0,0); const T* B = &src.rhs(0,0);
const int ldb = src.rhs.nc(); const int ldb = src.rhs.nc();
const T beta = 0; const T beta = add_to?1:0;
T* C = &dest(0,0); T* C = &dest(0,0);
const int ldc = src.nc(); const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout,rm + rm*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const T alpha = 1;
const T* A = &src.rhs.lhs(0,0);
const int lda = src.rhs.lhs.nc();
const T* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const T beta = 1;
T* C = &dest(0,0);
const int ldc = src.rhs.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*rm) DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*rm)
...@@ -268,94 +243,18 @@ namespace dlib ...@@ -268,94 +243,18 @@ namespace dlib
const int M = static_cast<int>(src.nr()); const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc()); const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc()); const int K = static_cast<int>(src.lhs.nc());
const T alpha = 1;
const T* A = &src.lhs.m(0,0); const T* A = &src.lhs.m(0,0);
const int lda = src.lhs.m.nc(); const int lda = src.lhs.m.nc();
const T* B = &src.rhs(0,0); const T* B = &src.rhs(0,0);
const int ldb = src.rhs.nc(); const int ldb = src.rhs.nc();
const T beta = 0; const T beta = add_to?1:0;
T* C = &dest(0,0); T* C = &dest(0,0);
const int ldc = src.nc(); const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout, rm + s*trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc());
const int K = static_cast<int>(src.rhs.m.lhs.nc());
const T alpha = src.rhs.s;
const T* A = &src.rhs.m.lhs.m(0,0);
const int lda = src.rhs.m.lhs.m.nc();
const T* B = &src.rhs.m.rhs(0,0);
const int ldb = src.rhs.m.rhs.nc();
const T beta = 1;
T* C = &dest(0,0);
const int ldc = dest.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout, s*trans(rm)*rm)
{
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.m.nr());
const int N = static_cast<int>(src.m.nc());
const int K = static_cast<int>(src.m.lhs.nc());
const T alpha = src.s;
const T* A = &src.m.lhs.m(0,0);
const int lda = src.m.lhs.m.nc();
const T* B = &src.m.rhs(0,0);
const int ldb = src.m.rhs.nc();
const T beta = 0;
T* C = &dest(0,0);
const int ldc = dest.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout, rm + trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const T alpha = 1;
const T* A = &src.rhs.lhs.m(0,0);
const int lda = src.rhs.lhs.m.nc();
const T* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const T beta = 1;
T* C = &dest(0,0);
const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
#endif // DLIB_USE_BLAS #endif // DLIB_USE_BLAS
} }
......
...@@ -23,24 +23,25 @@ namespace dlib ...@@ -23,24 +23,25 @@ namespace dlib
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template < /*! This file defines the default_matrix_multiply() function. It is a function
typename matrix_dest_type, that conforms to the following definition:
typename EXP1,
typename EXP2 template <
> typename matrix_dest_type,
typename enable_if_c<ma::matrix_is_vector<EXP1>::value == true && ma::matrix_is_vector<EXP2>::value == true>::type typename EXP1,
default_matrix_multiply ( typename EXP2
matrix_dest_type& dest, >
const EXP1& lhs, void default_matrix_multiply (
const EXP2& rhs matrix_dest_type& dest,
); const EXP1& lhs,
/*! const EXP2& rhs
requires );
- (lhs*rhs).destructively_aliases(dest) == false requires
- dest.nr() == (lhs*rhs).nr() - (lhs*rhs).destructively_aliases(dest) == false
- dest.nc() == (lhs*rhs).nc() - dest.nr() == (lhs*rhs).nr()
ensures - dest.nc() == (lhs*rhs).nc()
- #dest == dest + lhs*rhs ensures
- #dest == dest + lhs*rhs
!*/ !*/
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment