Commit c596cebb authored by Davis King

Improved the BLAS binding system. It should now break expressions up into

their proper basic BLAS function calls.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402764
parent 6b5f9a8a
...@@ -24,6 +24,39 @@ namespace dlib ...@@ -24,6 +24,39 @@ namespace dlib
namespace blas_bindings namespace blas_bindings
{ {
// ------------------------------------------------------------------------------------
// has_matrix_multiply<EXP>::value is a compile time constant telling us whether the
// matrix expression type EXP contains a matrix multiply.  Any type that is not
// matched by one of the specializations below reports false.
template <typename EXP>
struct has_matrix_multiply
{
    static const bool value = false;
};

// A multiply expression is itself a matrix multiply.
template <typename LHS, typename RHS>
struct has_matrix_multiply<matrix_multiply_exp<LHS,RHS> >
{
    static const bool value = true;
};

// A sum contains a multiply if either of its operands does.
template <typename LHS, typename RHS>
struct has_matrix_multiply<matrix_add_exp<LHS,RHS> >
{
    static const bool value = has_matrix_multiply<LHS>::value ||
                              has_matrix_multiply<RHS>::value;
};

// A difference contains a multiply if either of its operands does.
template <typename LHS, typename RHS>
struct has_matrix_multiply<matrix_subtract_exp<LHS,RHS> >
{
    static const bool value = has_matrix_multiply<LHS>::value ||
                              has_matrix_multiply<RHS>::value;
};

// Multiplication by a scalar: look at the wrapped expression.
template <typename EXP, bool Tb>
struct has_matrix_multiply<matrix_mul_scal_exp<EXP,Tb> >
{
    static const bool value = has_matrix_multiply<EXP>::value;
};

// Division by a scalar: look at the wrapped expression.
template <typename EXP>
struct has_matrix_multiply<matrix_div_scal_exp<EXP> >
{
    static const bool value = has_matrix_multiply<EXP>::value;
};

// Unary operation: look at the wrapped expression.
template <typename EXP, typename OP>
struct has_matrix_multiply<matrix_unary_exp<EXP,OP> >
{
    static const bool value = has_matrix_multiply<EXP>::value;
};
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template <typename T, typename U> template <typename T, typename U>
...@@ -110,10 +143,12 @@ namespace dlib ...@@ -110,10 +143,12 @@ namespace dlib
// Generic assign() overload taking any expression type EXP.
// NOTE: this hunk is a two-column diff rendering, so each line shows the old
// revision's text followed by the new revision's text.  In the new revision the
// function gains the alpha and add_to parameters and forwards all four arguments
// to matrix_assign_default().
template <typename EXP> template <typename EXP>
static void assign ( static void assign (
matrix<T,NR,NC,MM,L>& dest, matrix<T,NR,NC,MM,L>& dest,
const EXP& src const EXP& src,
typename src_exp::type alpha,
bool add_to
) )
{ {
matrix_assign_default(dest,src); matrix_assign_default(dest,src,alpha,add_to);
} }
// If we know this is a matrix multiply then apply the // If we know this is a matrix multiply then apply the
...@@ -122,98 +157,355 @@ namespace dlib ...@@ -122,98 +157,355 @@ namespace dlib
template <typename EXP1, typename EXP2> template <typename EXP1, typename EXP2>
static void assign ( static void assign (
matrix<T,NR,NC,MM,L>& dest, matrix<T,NR,NC,MM,L>& dest,
const matrix_multiply_exp<EXP1,EXP2>& src const matrix_multiply_exp<EXP1,EXP2>& src,
typename src_exp::type alpha,
bool add_to
) )
{
// At some point I need to improve the default (i.e. non BLAS) matrix
// multiplication algorithm...
if (alpha == 1)
{
if (add_to)
{
default_matrix_multiply(dest, src.lhs, src.rhs);
}
else
{ {
set_all_elements(dest,0); set_all_elements(dest,0);
default_matrix_multiply(dest, src.lhs, src.rhs); default_matrix_multiply(dest, src.lhs, src.rhs);
} }
}
else
{
if (add_to)
{
matrix<T,NR,NC,MM,L> temp(dest);
default_matrix_multiply(temp, src.lhs, src.rhs);
dest = alpha*temp;
}
else
{
set_all_elements(dest,0);
default_matrix_multiply(dest, src.lhs, src.rhs);
dest = alpha*dest;
}
}
}
};
template <typename EXP1, typename EXP2> // This is a macro to help us add overloads for the matrix_assign_blas_helper template.
static void assign ( // Using this macro it is easy to add overloads for arbitrary matrix expressions.
// DLIB_ADD_BLAS_BINDING(dest_layout, src_expression) expands to two things:
//   1. A trait struct whose name is built with BOOST_JOIN(blas,__LINE__).  Its
//      value compares sizeof(yes_type) with sizeof(test<T>(src_expression)).
//   2. The opening of a partial specialization of matrix_assign_blas_helper for
//      destinations with layout dest_layout, enabled via enable_if on that trait,
//      together with the opening of its static assign() function.
// The code written after the macro therefore becomes the body of assign() and can
// use dest, src, alpha and add_to.  DLIB_END_BLAS_BINDING supplies the two closing
// braces (function and struct) and the terminating semicolon.
#define DLIB_ADD_BLAS_BINDING( dest_layout, src_expression) \
template <typename T> struct BOOST_JOIN(blas,__LINE__) \
{ const static bool value = sizeof(yes_type) == sizeof(test<T>(src_expression)); }; \
template < typename T, long NR, long NC, typename MM, typename src_exp > \
struct matrix_assign_blas_helper<T,NR,NC,MM,dest_layout, src_exp, \
typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp> >::type > { \
static void assign ( \
matrix<T,NR,NC,MM,dest_layout>& dest, \
const src_exp& src, \
typename src_exp::type alpha, \
bool add_to \
) {
#define DLIB_END_BLAS_BINDING }};
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------- Forward Declarations -------------------
// Forward declarations of the matrix_assign_blas_proxy() overloads.  All of them
// take the destination matrix, a source expression, a scalar alpha and a bool
// add_to.  NOTE(review): judging from the contract of the four argument
// matrix_assign_default(), add_to == false presumably requests dest = alpha*src and
// add_to == true requests dest = dest + alpha*src -- confirm against the callers.

// Generic overload for an arbitrary source expression.
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const src_exp& src,
typename src_exp::type alpha,
bool add_to
);
/*!
requires
- src.aliases(dest) == false
!*/
// Overload for a sum of two expressions.  (The parameter lines below are two-column
// diff lines: old revision text followed by new revision text.)
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest, matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<matrix<T,NR,NC,MM,L>, matrix_multiply_exp<EXP1,EXP2> >& src const matrix_add_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
);
/*!
requires
- src.aliases(dest) == false
!*/
// Overload for an expression multiplied by a scalar.
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, bool Sb
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_mul_scal_exp<src_exp,Sb>& src,
typename src_exp::type alpha,
bool add_to
);
/*!
requires
- src.aliases(dest) == false
!*/
// Overload for a difference of two expressions.
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_subtract_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
);
/*!
requires
- src.aliases(dest) == false
!*/
// ------------------------------------------------------------------------------------
// Forward declarations of the matrix_assign_blas() overloads.

// Generic overload: assigns an arbitrary source expression to dest.
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
);
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
);
/*!
This function catches the expressions of the form:
M = M + exp;
and converts them into the appropriate matrix_assign_blas() call.
This is an important case to catch because it is the expression used
to represent the += matrix operator.
!*/
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const matrix_subtract_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
);
/*!
This function catches the expressions of the form:
M = M - exp;
and converts them into the appropriate matrix_assign_blas() call.
This is an important case to catch because it is the expression used
to represent the -= matrix operator.
!*/
// End of forward declarations for overloaded matrix_assign_blas functions
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// Generic matrix_assign_blas_proxy(): takes any source expression type src_exp.
// NOTE: the last lines of this hunk are two-column diff lines (old revision text
// followed by new revision text).  In the new revision the body is the single call
// that hands dest, src, alpha and add_to to
// matrix_assign_blas_helper<T,NR,NC,MM,L,src_exp>::assign().
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const src_exp& src,
typename src_exp::type alpha,
bool add_to
) )
{ {
if (&dest == &src.lhs) matrix_assign_blas_helper<T,NR,NC,MM,L,src_exp>::assign(dest,src,alpha,add_to);
}
// ------------------------------------------------------------------------------------
// matrix_assign_blas_proxy() for a sum of two expressions.
// NOTE: the body lines are two-column diff lines (old revision text followed by new
// revision text).  New revision behaviour: if either operand has a cost above 9 the
// sum is evaluated one operand at a time -- src.lhs with the caller's add_to, then
// src.rhs with add_to forced to true so that it accumulates onto the first result.
// Otherwise the whole sum is handed to matrix_assign_default().
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
)
{ {
default_matrix_multiply(dest, src.rhs.lhs, src.rhs.rhs); if (src_exp::cost > 9 || src_exp2::cost > 9)
{
matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to);
matrix_assign_blas_proxy(dest, src.rhs, alpha, true);
} }
else else
{ {
dest = src.lhs; matrix_assign_default(dest, src, alpha, add_to);
default_matrix_multiply(dest, src.rhs.lhs, src.rhs.rhs);
} }
} }
template <typename EXP1, typename EXP2> // ------------------------------------------------------------------------------------
static void assign (
// matrix_assign_blas_proxy() for an expression multiplied by a scalar.
// NOTE: several lines are two-column diff lines (old revision text followed by new
// revision text).  New revision behaviour: the scalar src.s is folded into alpha
// and the call is repeated on the wrapped expression src.m.
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, bool Sb
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest, matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<matrix<T,NR,NC,MM,L>, matrix_add_exp<EXP1,EXP2> >& src const matrix_mul_scal_exp<src_exp,Sb>& src,
typename src_exp::type alpha,
bool add_to
) )
{ {
if (EXP1::cost > 50 || EXP2::cost > 5) matrix_assign_blas_proxy(dest, src.m, alpha*src.s, add_to);
}
// ------------------------------------------------------------------------------------
// matrix_assign_blas_proxy() for a difference of two expressions.
// NOTE: the body lines are two-column diff lines (old revision text followed by new
// revision text).  New revision behaviour: if either operand has a cost above 9,
// src.lhs is processed with (alpha, add_to) and src.rhs is then accumulated with
// (-alpha, true).  Otherwise the whole difference goes to matrix_assign_default().
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp, typename src_exp2
>
void matrix_assign_blas_proxy (
matrix<T,NR,NC,MM,L>& dest,
const matrix_subtract_exp<src_exp, src_exp2>& src,
typename src_exp::type alpha,
bool add_to
)
{
if (src_exp::cost > 9 || src_exp2::cost > 9)
{ {
matrix_assign(dest, src.lhs + src.rhs.lhs); matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to);
matrix_assign(dest, src.lhs + src.rhs.rhs); matrix_assign_blas_proxy(dest, src.rhs, -alpha, true);
} }
else else
{ {
matrix_assign_default(dest,src); matrix_assign_default(dest, src, alpha, add_to);
} }
} }
template <typename EXP2> // ------------------------------------------------------------------------------------
static void assign ( // ------------------------------------------------------------------------------------
// Once we get into this function it means that we are dealing with a matrix of float,
// double, complex<float>, or complex<double> and the src_exp contains at least one
// matrix multiply.
// Generic matrix_assign_blas(): performs dest = src by calling
// matrix_assign_blas_proxy() with alpha == 1 and add_to == false.  The proxy
// requires that src does not alias dest, so when it does the result is built in
// a temporary and then swapped into dest.
template <
    typename T, long NR, long NC, typename MM, typename L,
    typename src_exp
    >
void matrix_assign_blas (
    matrix<T,NR,NC,MM,L>& dest,
    const src_exp& src
)
{
    if (src.aliases(dest))
    {
        // The proxy writes straight into the elements of its destination and
        // never resizes it, so the temporary must already have the right
        // dimensions.  A default constructed matrix would not be sized at all,
        // so copy construct from dest to pick up its dimensions.  The copied
        // values are overwritten because add_to is false.
        matrix<T,NR,NC,MM,L> temp(dest);
        matrix_assign_blas_proxy(temp,src,1,false);
        temp.swap(dest);
    }
    else
    {
        matrix_assign_blas_proxy(dest,src,1,false);
    }
}
// ------------------------------------------------------------------------------------
// matrix_assign_blas() for expressions of the form  M + exp  where M has the same
// matrix type as dest.
// NOTE: most lines are two-column diff lines (old revision text followed by new
// revision text).  New revision behaviour:
//   - If src_exp::cost > 5 and exp does not alias dest: dest is first set to M
//     (skipped when M is dest itself) and exp is then accumulated onto it through
//     matrix_assign_blas_proxy() with alpha == 1 and add_to == true.
//   - If src_exp::cost > 5 and exp aliases dest: the same is done on a temporary
//     copy of M, which is swapped into dest afterwards.
//   - Otherwise the whole expression is handed to matrix_assign_default().
template <typename EXP1, typename EXP2> template <
static void assign ( typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest, matrix<T,NR,NC,MM,L>& dest,
const matrix_add_exp<EXP1,EXP2>& src const matrix_add_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
) )
{ {
if (EXP1::cost > 50 || EXP2::cost > 50) if (src_exp::cost > 5)
{ {
matrix_assign(dest,src.lhs); if (src.rhs.aliases(dest) == false)
matrix_assign(dest, dest + src.rhs); {
if (&src.lhs != &dest)
{
dest = src.lhs;
}
matrix_assign_blas_proxy(dest, src.rhs, 1, true);
}
else
{
matrix<T,NR,NC,MM,L> temp(src.lhs);
matrix_assign_blas_proxy(temp, src.rhs, 1, true);
temp.swap(dest);
}
} }
else else
{ {
matrix_assign_default(dest,src); matrix_assign_default(dest,src);
} }
} }
};
// This is a macro to help us add overloads for the matrix_assign_blas_helper template. // ------------------------------------------------------------------------------------
// Using this macro it is easy to add overloads for arbitrary matrix expressions.
#define DLIB_ADD_BLAS_BINDING( dest_layout, src_expression) \
template <typename T> struct BOOST_JOIN(blas,__LINE__) \
{ const static bool value = sizeof(yes_type) == sizeof(test<T>(src_expression)); }; \
template < typename T, long NR, long NC, typename MM, typename src_exp > \
struct matrix_assign_blas_helper<T,NR,NC,MM,dest_layout, src_exp, \
typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp> >::type > { \
static void assign ( \
matrix<T,NR,NC,MM,dest_layout>& dest, \
const src_exp& src \
) {
// matrix_assign_blas() for expressions of the form  M - exp  where M has the same
// matrix type as dest.
// NOTE: the first line below is a two-column diff line; its left part is the tail
// of the removed old macro definition and its right part starts this template.
// Behaviour:
//   - If src_exp::cost > 5 and exp does not alias dest: dest is first set to M
//     (skipped when M is dest itself) and exp is then accumulated onto it through
//     matrix_assign_blas_proxy() with alpha == -1 and add_to == true.
//   - If src_exp::cost > 5 and exp aliases dest: the same is done on a temporary
//     copy of M, which is swapped into dest afterwards.
//   - Otherwise the whole expression is handed to matrix_assign_default().
#define DLIB_END_BLAS_BINDING }}; template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const matrix_subtract_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
)
{
if (src_exp::cost > 5)
{
if (src.rhs.aliases(dest) == false)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
matrix_assign_blas_proxy(dest, src.rhs, -1, true);
}
else
{
matrix<T,NR,NC,MM,L> temp(src.lhs);
matrix_assign_blas_proxy(temp, src.rhs, -1, true);
temp.swap(dest);
}
}
else
{
matrix_assign_default(dest,src);
}
}
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
} // end of namespace blas_bindings } // end of namespace blas_bindings
...@@ -227,13 +519,14 @@ namespace dlib ...@@ -227,13 +519,14 @@ namespace dlib
inline typename enable_if_c<(is_same_type<T,float>::value || inline typename enable_if_c<(is_same_type<T,float>::value ||
is_same_type<T,double>::value || is_same_type<T,double>::value ||
is_same_type<T,std::complex<float> >::value || is_same_type<T,std::complex<float> >::value ||
is_same_type<T,std::complex<double> >::value) is_same_type<T,std::complex<double> >::value) &&
blas_bindings::has_matrix_multiply<src_exp>::value
>::type matrix_assign_big ( >::type matrix_assign_big (
matrix<T,NR,NC,MM,L>& dest, matrix<T,NR,NC,MM,L>& dest,
const src_exp& src const src_exp& src
) )
{ {
blas_bindings::matrix_assign_blas_helper<T,NR,NC,MM,L,src_exp>::assign(dest,src); blas_bindings::matrix_assign_blas(dest,src);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -44,6 +44,12 @@ namespace dlib ...@@ -44,6 +44,12 @@ namespace dlib
EXP1& dest, EXP1& dest,
const EXP2& src const EXP2& src
) )
/*!
requires
- src.destructively_aliases(dest) == false
ensures
- #dest == src
!*/
{ {
for (long r = 0; r < src.nr(); ++r) for (long r = 0; r < src.nr(); ++r)
{ {
...@@ -54,6 +60,83 @@ namespace dlib ...@@ -54,6 +60,83 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
template <typename EXP1, typename EXP2>
inline static void matrix_assign_default (
    EXP1& dest,
    const EXP2& src,
    typename EXP2::type alpha,
    bool add_to
)
/*!
    requires
        - src.destructively_aliases(dest) == false
    ensures
        - if (add_to == false) then
            - #dest == alpha*src
        - else
            - #dest == dest + alpha*src
!*/
{
    // Each combination of add_to and alpha gets its own loop so that the common
    // alpha == 1 and alpha == -1 cases do not pay for a multiplication.

    if (!add_to)
    {
        // Overwrite dest.
        if (alpha == 1)
        {
            for (long row = 0; row < src.nr(); ++row)
                for (long col = 0; col < src.nc(); ++col)
                    dest(row,col) = src(row,col);
        }
        else
        {
            for (long row = 0; row < src.nr(); ++row)
                for (long col = 0; col < src.nc(); ++col)
                    dest(row,col) = alpha*src(row,col);
        }
        return;
    }

    // From here on we accumulate into dest.
    if (alpha == 1)
    {
        for (long row = 0; row < src.nr(); ++row)
            for (long col = 0; col < src.nc(); ++col)
                dest(row,col) += src(row,col);
        return;
    }

    if (alpha == -1)
    {
        for (long row = 0; row < src.nr(); ++row)
            for (long col = 0; col < src.nc(); ++col)
                dest(row,col) -= src(row,col);
        return;
    }

    for (long row = 0; row < src.nr(); ++row)
        for (long col = 0; col < src.nc(); ++col)
            dest(row,col) += alpha*src(row,col);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -83,7 +166,6 @@ namespace dlib ...@@ -83,7 +166,6 @@ namespace dlib
- src.destructively_aliases(dest) == false - src.destructively_aliases(dest) == false
ensures ensures
- #dest == src - #dest == src
- the part of dest outside the above sub matrix remains unchanged
!*/ !*/
{ {
// Call src.ref() here so that the derived type of the matrix_exp shows // Call src.ref() here so that the derived type of the matrix_exp shows
...@@ -107,7 +189,6 @@ namespace dlib ...@@ -107,7 +189,6 @@ namespace dlib
- src.destructively_aliases(dest) == false - src.destructively_aliases(dest) == false
ensures ensures
- #dest == src - #dest == src
- the part of dest outside the above sub matrix remains unchanged
!*/ !*/
{ {
matrix_assign_default(dest,src.ref()); matrix_assign_default(dest,src.ref());
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include "cblas.h" #include "cblas.h"
#endif #endif
#include <iostream>
namespace dlib namespace dlib
{ {
...@@ -219,45 +221,18 @@ namespace dlib ...@@ -219,45 +221,18 @@ namespace dlib
const int M = static_cast<int>(src.nr()); const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc()); const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc()); const int K = static_cast<int>(src.lhs.nc());
const T alpha = 1;
const T* A = &src.lhs(0,0); const T* A = &src.lhs(0,0);
const int lda = src.lhs.nc(); const int lda = src.lhs.nc();
const T* B = &src.rhs(0,0); const T* B = &src.rhs(0,0);
const int ldb = src.rhs.nc(); const int ldb = src.rhs.nc();
const T beta = 0; const T beta = add_to?1:0;
T* C = &dest(0,0); T* C = &dest(0,0);
const int ldc = src.nc(); const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout,rm + rm*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const T alpha = 1;
const T* A = &src.rhs.lhs(0,0);
const int lda = src.rhs.lhs.nc();
const T* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const T beta = 1;
T* C = &dest(0,0);
const int ldc = src.rhs.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*rm) DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*rm)
...@@ -268,94 +243,18 @@ namespace dlib ...@@ -268,94 +243,18 @@ namespace dlib
const int M = static_cast<int>(src.nr()); const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc()); const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc()); const int K = static_cast<int>(src.lhs.nc());
const T alpha = 1;
const T* A = &src.lhs.m(0,0); const T* A = &src.lhs.m(0,0);
const int lda = src.lhs.m.nc(); const int lda = src.lhs.m.nc();
const T* B = &src.rhs(0,0); const T* B = &src.rhs(0,0);
const int ldb = src.rhs.nc(); const int ldb = src.rhs.nc();
const T beta = 0; const T beta = add_to?1:0;
T* C = &dest(0,0); T* C = &dest(0,0);
const int ldc = src.nc(); const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout, rm + s*trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc());
const int K = static_cast<int>(src.rhs.m.lhs.nc());
const T alpha = src.rhs.s;
const T* A = &src.rhs.m.lhs.m(0,0);
const int lda = src.rhs.m.lhs.m.nc();
const T* B = &src.rhs.m.rhs(0,0);
const int ldb = src.rhs.m.rhs.nc();
const T beta = 1;
T* C = &dest(0,0);
const int ldc = dest.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout, s*trans(rm)*rm)
{
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.m.nr());
const int N = static_cast<int>(src.m.nc());
const int K = static_cast<int>(src.m.lhs.nc());
const T alpha = src.s;
const T* A = &src.m.lhs.m(0,0);
const int lda = src.m.lhs.m.nc();
const T* B = &src.m.rhs(0,0);
const int ldb = src.m.rhs.nc();
const T beta = 0;
T* C = &dest(0,0);
const int ldc = dest.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout, rm + trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const T alpha = 1;
const T* A = &src.rhs.lhs.m(0,0);
const int lda = src.rhs.lhs.m.nc();
const T* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const T beta = 1;
T* C = &dest(0,0);
const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
#endif // DLIB_USE_BLAS #endif // DLIB_USE_BLAS
} }
......
...@@ -23,18 +23,19 @@ namespace dlib ...@@ -23,18 +23,19 @@ namespace dlib
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
/*! This file defines the default_matrix_multiply() function. It is a function
that conforms to the following definition:
template < template <
typename matrix_dest_type, typename matrix_dest_type,
typename EXP1, typename EXP1,
typename EXP2 typename EXP2
> >
typename enable_if_c<ma::matrix_is_vector<EXP1>::value == true && ma::matrix_is_vector<EXP2>::value == true>::type void default_matrix_multiply (
default_matrix_multiply (
matrix_dest_type& dest, matrix_dest_type& dest,
const EXP1& lhs, const EXP1& lhs,
const EXP2& rhs const EXP2& rhs
); );
/*!
requires requires
- (lhs*rhs).destructively_aliases(dest) == false - (lhs*rhs).destructively_aliases(dest) == false
- dest.nr() == (lhs*rhs).nr() - dest.nr() == (lhs*rhs).nr()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment