Commit dea6bc04 authored by Davis King's avatar Davis King

- Added overloads to cause scalar multiplications to combine and percolate out
     of matrix multiplications.
   - Worked more on the optimized overloads that call BLAS functions.
   - Changed the code inside the matrix assignment overloads so that it works
     better with GCC's optimizer.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402761
parent d346bdcf
This diff is collapsed.
...@@ -308,6 +308,8 @@ namespace dlib ...@@ -308,6 +308,8 @@ namespace dlib
const static bool lhs_is_costly = matrix_traits<matrix_multiply_exp>::lhs_is_costly; const static bool lhs_is_costly = matrix_traits<matrix_multiply_exp>::lhs_is_costly;
const static bool rhs_is_costly = matrix_traits<matrix_multiply_exp>::rhs_is_costly; const static bool rhs_is_costly = matrix_traits<matrix_multiply_exp>::rhs_is_costly;
const static bool either_is_costly = lhs_is_costly || rhs_is_costly;
const static bool both_are_costly = lhs_is_costly && rhs_is_costly;
typedef typename conditional_matrix_temp<const LHS,lhs_is_costly == false>::type LHS_ref_type; typedef typename conditional_matrix_temp<const LHS,lhs_is_costly == false>::type LHS_ref_type;
typedef typename conditional_matrix_temp<const RHS,rhs_is_costly == false>::type RHS_ref_type; typedef typename conditional_matrix_temp<const RHS,rhs_is_costly == false>::type RHS_ref_type;
...@@ -387,6 +389,59 @@ namespace dlib ...@@ -387,6 +389,59 @@ namespace dlib
return matrix_multiply_exp<EXP1, EXP2>(m1.ref(), m2.ref()); return matrix_multiply_exp<EXP1, EXP2>(m1.ref(), m2.ref());
} }
template <typename M, bool use_reference = true>
class matrix_mul_scal_exp;
// -------------------------
// Now we declare some overloads that cause any scalar multiplications to percolate
// up and outside of any matrix multiplies. Note that we are using the non-reference containing
// mode of the matrix_mul_scal_exp object since we are passing in locally constructed matrix_multiply_exp
// objects. So the matrix_mul_scal_exp object will contain copies of matrix_multiply_exp objects
// rather than references to them. This could result in extra matrix copies if the matrix_multiply_exp
// decided it should evaluate any of its arguments. So we also try to not apply this percolating operation
// if the matrix_multiply_exp would contain a fully evaluated copy of the original matrix_mul_scal_exp
// expression.
//
// Also, the reason we want to apply this transformation in the first place is that it (1) makes
// the expressions going into matrix multiply expressions simpler and (2) makes it a lot more
// straightforward to bind BLAS calls to matrix expressions involving scalar multiplies.
// (s1*A) * (s2*B)  ==>  (s1*s2) * (A*B)
//
// Both operands carry a scalar factor.  The two factors are multiplied together and
// attached to a multiply expression built from the unscaled operands m1.m and m2.m.
// The returned matrix_mul_scal_exp uses use_reference == false, so it stores the locally
// constructed matrix_multiply_exp by value rather than by reference.
//
// This overload is disabled when both_are_costly is true for the multiply expression
// formed from the two scaled operands.
template < typename EXP1, typename EXP2 >
inline const typename disable_if_c< matrix_multiply_exp<matrix_mul_scal_exp<EXP1>, matrix_mul_scal_exp<EXP2> >::both_are_costly ,
    matrix_mul_scal_exp<matrix_multiply_exp<EXP1, EXP2>,false> >::type operator* (
    const matrix_mul_scal_exp<EXP1>& m1,
    const matrix_mul_scal_exp<EXP2>& m2
)
{
    typedef matrix_multiply_exp<EXP1, EXP2> product_type;
    typedef matrix_mul_scal_exp<product_type,false> scaled_product_type;

    return scaled_product_type(product_type(m1.m, m2.m), m1.s*m2.s);
}
// (s*A) * B  ==>  s * (A*B)
//
// Only the left operand carries a scalar factor.  The factor m1.s is moved outside a
// multiply expression built from the unscaled m1.m and m2.ref().  The result stores the
// locally constructed matrix_multiply_exp by value (use_reference == false).
//
// This overload is disabled when lhs_is_costly is true for the multiply expression formed
// from the scaled left operand and EXP2.
template < typename EXP1, typename EXP2 >
inline const typename disable_if_c< matrix_multiply_exp<matrix_mul_scal_exp<EXP1>, EXP2 >::lhs_is_costly ,
    matrix_mul_scal_exp<matrix_multiply_exp<EXP1, EXP2>,false> >::type operator* (
    const matrix_mul_scal_exp<EXP1>& m1,
    const matrix_exp<EXP2>& m2
)
{
    typedef matrix_multiply_exp<EXP1, EXP2> product_type;
    typedef matrix_mul_scal_exp<product_type,false> scaled_product_type;

    return scaled_product_type(product_type(m1.m, m2.ref()), m1.s);
}
// A * (s*B)  ==>  s * (A*B)
//
// Only the right operand carries a scalar factor.  The factor m2.s is moved outside a
// multiply expression built from m1.ref() and the unscaled m2.m.  The result stores the
// locally constructed matrix_multiply_exp by value (use_reference == false).
//
// This overload is disabled when rhs_is_costly is true for the multiply expression formed
// from EXP1 and the scaled right operand.
template < typename EXP1, typename EXP2 >
inline const typename disable_if_c< matrix_multiply_exp<EXP1, matrix_mul_scal_exp<EXP2> >::rhs_is_costly ,
    matrix_mul_scal_exp<matrix_multiply_exp<EXP1, EXP2>,false> >::type operator* (
    const matrix_exp<EXP1>& m1,
    const matrix_mul_scal_exp<EXP2>& m2
)
{
    typedef matrix_multiply_exp<EXP1, EXP2> product_type;
    typedef matrix_mul_scal_exp<product_type,false> scaled_product_type;

    return scaled_product_type(product_type(m1.ref(), m2.m), m2.s);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename LHS, typename RHS> template <typename LHS, typename RHS>
...@@ -687,23 +742,20 @@ namespace dlib ...@@ -687,23 +742,20 @@ namespace dlib
template < template <
typename EXP, typename EXP,
typename S typename S
> >
inline const matrix_div_scal_exp<EXP> operator/ ( inline const typename enable_if_c<std::numeric_limits<typename EXP::type>::is_integer, matrix_div_scal_exp<EXP> >::type operator/ (
const matrix_exp<EXP>& m, const matrix_exp<EXP>& m,
const S& s const S& s
) )
{ {
return matrix_div_scal_exp<EXP>(m.ref(),s); return matrix_div_scal_exp<EXP>(m.ref(),static_cast<typename EXP::type>(s));
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename M> template <typename M, bool use_reference >
class matrix_mul_scal_exp; struct matrix_traits<matrix_mul_scal_exp<M,use_reference> >
template <typename M>
struct matrix_traits<matrix_mul_scal_exp<M> >
{ {
typedef typename M::type type; typedef typename M::type type;
typedef typename M::mem_manager_type mem_manager_type; typedef typename M::mem_manager_type mem_manager_type;
...@@ -713,10 +765,15 @@ namespace dlib ...@@ -713,10 +765,15 @@ namespace dlib
const static long cost = M::cost+1; const static long cost = M::cost+1;
}; };
// Compile-time selector: conditional_reference<T,true>::type is T&, while
// conditional_reference<T,false>::type is plain T.
template <typename T, bool is_ref>
struct conditional_reference
{
    // Primary template covers is_ref == false: keep T as a value type.
    typedef T type;
};

template <typename T>
struct conditional_reference<T,true>
{
    // is_ref == true: turn T into a reference type.
    typedef T& type;
};
template < template <
typename M typename M,
bool use_reference
> >
class matrix_mul_scal_exp : public matrix_exp<matrix_mul_scal_exp<M> > class matrix_mul_scal_exp : public matrix_exp<matrix_mul_scal_exp<M,use_reference> >
{ {
/*! /*!
REQUIREMENTS ON M REQUIREMENTS ON M
...@@ -773,7 +830,9 @@ namespace dlib ...@@ -773,7 +830,9 @@ namespace dlib
long nc ( long nc (
) const { return m.nc(); } ) const { return m.nc(); }
const M& m; typedef typename conditional_reference<const M,use_reference>::type M_ref_type;
M_ref_type m;
const type s; const type s;
}; };
...@@ -789,6 +848,19 @@ namespace dlib ...@@ -789,6 +848,19 @@ namespace dlib
return matrix_mul_scal_exp<EXP>(m.ref(),s); return matrix_mul_scal_exp<EXP>(m.ref(),s);
} }
// (m.s * M) * s  ==>  (s*m.s) * M
//
// Multiplying an already scaled expression by another scalar folds the two scalars into
// one instead of nesting a second matrix_mul_scal_exp.  Disabled when S is a matrix type.
// The result refers to m.m, the expression held inside the operand.
template <
    typename EXP,
    typename S,
    bool B
    >
inline typename disable_if<is_matrix<S>, const matrix_mul_scal_exp<EXP> >::type operator* (
    const matrix_mul_scal_exp<EXP,B>& m,
    const S& s
)
{
    typedef matrix_mul_scal_exp<EXP> folded_type;
    return folded_type(m.m, s*m.s);
}
template < template <
typename EXP, typename EXP,
typename S typename S
...@@ -802,36 +874,44 @@ namespace dlib ...@@ -802,36 +874,44 @@ namespace dlib
} }
template < template <
typename EXP typename EXP,
typename S,
bool B
> >
inline const matrix_mul_scal_exp<EXP> operator/ ( inline typename disable_if<is_matrix<S>, const matrix_mul_scal_exp<EXP> >::type operator* (
const matrix_exp<EXP>& m, const S& s,
const float& s const matrix_mul_scal_exp<EXP,B>& m
) )
{ {
return matrix_mul_scal_exp<EXP>(m.ref(),1.0f/s); return matrix_mul_scal_exp<EXP>(m.m,s*m.s);
} }
template < template <
typename EXP typename EXP ,
typename S
> >
inline const matrix_mul_scal_exp<EXP> operator/ ( inline const typename disable_if_c<std::numeric_limits<typename EXP::type>::is_integer, matrix_mul_scal_exp<EXP> >::type operator/ (
const matrix_exp<EXP>& m, const matrix_exp<EXP>& m,
const double& s const S& s
) )
{ {
return matrix_mul_scal_exp<EXP>(m.ref(),1.0/s); typedef typename EXP::type type;
const type one = 1;
return matrix_mul_scal_exp<EXP>(m.ref(),one/static_cast<type>(s));
} }
template < template <
typename EXP typename EXP,
bool B,
typename S
> >
inline const matrix_mul_scal_exp<EXP> operator/ ( inline const typename disable_if_c<std::numeric_limits<typename EXP::type>::is_integer, matrix_mul_scal_exp<EXP> >::type operator/ (
const matrix_exp<EXP>& m, const matrix_mul_scal_exp<EXP,B>& m,
const long double& s const S& s
) )
{ {
return matrix_mul_scal_exp<EXP>(m.ref(),1.0/s); typedef typename EXP::type type;
return matrix_mul_scal_exp<EXP>(m.m,m.s/static_cast<type>(s));
} }
template < template <
...@@ -844,6 +924,17 @@ namespace dlib ...@@ -844,6 +924,17 @@ namespace dlib
return matrix_mul_scal_exp<EXP>(m.ref(),-1); return matrix_mul_scal_exp<EXP>(m.ref(),-1);
} }
// -(m.s * M)  ==>  (-1*m.s) * M
//
// Negating an already scaled expression just flips the sign of its scalar factor rather
// than wrapping the expression in another matrix_mul_scal_exp.
template <
    typename EXP,
    bool B
    >
inline const matrix_mul_scal_exp<EXP> operator- (
    const matrix_mul_scal_exp<EXP,B>& m
)
{
    typedef matrix_mul_scal_exp<EXP> negated_type;
    return negated_type(m.m, -1*m.s);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -1261,8 +1352,18 @@ namespace dlib ...@@ -1261,8 +1352,18 @@ namespace dlib
(is_matrix<typename EXP::type>::value == true)); (is_matrix<typename EXP::type>::value == true));
if (m.destructively_aliases(*this) == false) if (m.destructively_aliases(*this) == false)
{ {
set_size(m.nr(),m.nc()); // This if statement is seemingly unnecessary since set_size() contains this
matrix_assign(*this, m); // exact same if statement. However, structuring the code this way causes
// gcc to handle the way it inlines this function in a much more favorable way.
if (data.nr() == m.nr() && data.nc() == m.nc())
{
matrix_assign(*this, m);
}
else
{
set_size(m.nr(),m.nc());
matrix_assign(*this, m);
}
} }
else else
{ {
......
...@@ -69,7 +69,7 @@ namespace dlib ...@@ -69,7 +69,7 @@ namespace dlib
struct same_exp<matrix_subtract_exp<Tlhs,Trhs>, matrix_subtract_exp<Ulhs,Urhs> > struct same_exp<matrix_subtract_exp<Tlhs,Trhs>, matrix_subtract_exp<Ulhs,Urhs> >
{ const static bool value = same_exp<Tlhs,Ulhs>::value && same_exp<Trhs,Urhs>::value; }; { const static bool value = same_exp<Tlhs,Ulhs>::value && same_exp<Trhs,Urhs>::value; };
template <typename T, typename U> struct same_exp<matrix_mul_scal_exp<T>, matrix_mul_scal_exp<U> > template <typename T, typename U, bool Tb, bool Ub> struct same_exp<matrix_mul_scal_exp<T,Tb>, matrix_mul_scal_exp<U,Ub> >
{ const static bool value = same_exp<T,U>::value; }; { const static bool value = same_exp<T,U>::value; };
template <typename T, typename U> struct same_exp<matrix_div_scal_exp<T>, matrix_div_scal_exp<U> > template <typename T, typename U> struct same_exp<matrix_div_scal_exp<T>, matrix_div_scal_exp<U> >
...@@ -113,13 +113,7 @@ namespace dlib ...@@ -113,13 +113,7 @@ namespace dlib
const EXP& src const EXP& src
) )
{ {
for (long r = 0; r < src.nr(); ++r) matrix_assign_default(dest,src);
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) = src(r,c);
}
}
} }
// If we know this is a matrix multiply then apply the // If we know this is a matrix multiply then apply the
...@@ -134,6 +128,75 @@ namespace dlib ...@@ -134,6 +128,75 @@ namespace dlib
set_all_elements(dest,0); set_all_elements(dest,0);
default_matrix_multiply(dest, src.lhs, src.rhs); default_matrix_multiply(dest, src.lhs, src.rhs);
} }
// Handles dest = M + A*B where M has exactly the same matrix type as dest.
//
// dest is first made equal to M (skipped when dest already is M), then
// default_matrix_multiply() is run on dest with the two multiply operands.
// NOTE(review): this relies on default_matrix_multiply() adding the product into dest
// rather than overwriting it -- confirm against its definition.
template <typename EXP1, typename EXP2>
static void assign (
    matrix<T,NR,NC,MM,L>& dest,
    const matrix_add_exp<matrix<T,NR,NC,MM,L>, matrix_multiply_exp<EXP1,EXP2> >& src
)
{
    const bool dest_is_lhs = (&dest == &src.lhs);
    if (!dest_is_lhs)
    {
        dest = src.lhs;
    }

    default_matrix_multiply(dest, src.rhs.lhs, src.rhs.rhs);
}
// Handles dest = M + (E1 + E2) where M has exactly the same matrix type as dest.
//
// When either inner term is costly, the sum is built in two passes so that each term
// can be routed through its own matrix_assign() overload:
//     pass 1:  dest = M + E1
//     pass 2:  dest = dest + E2
// Pass 2 must add to dest, not to M.  Adding to M would throw away the E1 contribution
// from pass 1 whenever dest and M are different objects.
//
// The cost threshold is 50 for both terms, matching the other assign() overloads for
// matrix_add_exp in this helper.
//
// NOTE(review): if E2 reads from dest, pass 2 sees the value written by pass 1 rather
// than the value dest had on entry -- confirm the caller's aliasing check rules that out.
template <typename EXP1, typename EXP2>
static void assign (
    matrix<T,NR,NC,MM,L>& dest,
    const matrix_add_exp<matrix<T,NR,NC,MM,L>, matrix_add_exp<EXP1,EXP2> >& src
)
{
    if (EXP1::cost > 50 || EXP2::cost > 50)
    {
        matrix_assign(dest, src.lhs + src.rhs.lhs);
        matrix_assign(dest, dest + src.rhs.rhs);
    }
    else
    {
        matrix_assign_default(dest,src);
    }
}
// Handles dest = M + E where M has exactly the same matrix type as dest.
//
// If E is cheap (cost <= 50), or dest already is M, the expression is evaluated with
// matrix_assign_default().  Otherwise dest is first set to M and the sum is then
// re-dispatched as dest + E through matrix_assign().
template <typename EXP2>
static void assign (
    matrix<T,NR,NC,MM,L>& dest,
    const matrix_add_exp<matrix<T,NR,NC,MM,L>,EXP2>& src
)
{
    if (EXP2::cost <= 50 || &dest == &src.lhs)
    {
        matrix_assign_default(dest,src);
        return;
    }

    dest = src.lhs;
    matrix_assign(dest, dest + src.rhs);
}
// Handles the general dest = E1 + E2 case.
//
// When neither term is costly (cost <= 50 for both) the sum is evaluated with
// matrix_assign_default().  Otherwise E1 is assigned into dest first and E2 is then
// added on top via a second matrix_assign() of dest + E2, so each term is dispatched
// through its own matrix_assign() overload.
//
// NOTE(review): if E2 reads from dest, the second step sees the value written by the
// first step -- confirm the caller's aliasing check rules that out.
template <typename EXP1, typename EXP2>
static void assign (
    matrix<T,NR,NC,MM,L>& dest,
    const matrix_add_exp<EXP1,EXP2>& src
)
{
    const bool has_costly_term = (EXP1::cost > 50 || EXP2::cost > 50);
    if (!has_costly_term)
    {
        matrix_assign_default(dest,src);
        return;
    }

    matrix_assign(dest,src.lhs);
    matrix_assign(dest, dest + src.rhs);
}
}; };
// This is a macro to help us add overloads for the matrix_assign_blas_helper template. // This is a macro to help us add overloads for the matrix_assign_blas_helper template.
......
...@@ -39,27 +39,10 @@ namespace dlib ...@@ -39,27 +39,10 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// In newer versions of GCC it is necessary to explicitly tell it to not try to template <typename EXP1, typename EXP2>
// inline the matrix_assign() function when working with matrix objects that inline static void matrix_assign_default (
// don't have dimensions that are known at compile time. Doing this makes the EXP1& dest,
// resulting binaries a lot faster when -O3 is used. This whole deal with const EXP2& src
// different versions of matrix_assign() is just to support getting the right
// inline behavior out of GCC.
#ifdef __GNUC__
#define DLIB_DONT_INLINE __attribute__((noinline))
#define DLIB_ALWAYS_INLINE __attribute__((always_inline))
#else
#define DLIB_DONT_INLINE
#define DLIB_ALWAYS_INLINE
#endif
template <
typename matrix_dest_type,
typename src_exp
>
DLIB_DONT_INLINE void matrix_assign_big (
matrix_dest_type& dest,
const matrix_exp<src_exp>& src
) )
{ {
for (long r = 0; r < src.nr(); ++r) for (long r = 0; r < src.nr(); ++r)
...@@ -71,22 +54,18 @@ namespace dlib ...@@ -71,22 +54,18 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
template < template <
typename matrix_dest_type, typename matrix_dest_type,
typename src_exp typename src_exp
> >
inline void matrix_assign_small ( void matrix_assign_big (
matrix_dest_type& dest, matrix_dest_type& dest,
const matrix_exp<src_exp>& src const matrix_exp<src_exp>& src
) )
{ {
for (long r = 0; r < src.nr(); ++r) matrix_assign_default(dest,src);
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) = src(r,c);
}
}
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -131,7 +110,7 @@ namespace dlib ...@@ -131,7 +110,7 @@ namespace dlib
- the part of dest outside the above sub matrix remains unchanged - the part of dest outside the above sub matrix remains unchanged
!*/ !*/
{ {
matrix_assign_small(dest,src.ref()); matrix_assign_default(dest,src.ref());
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "matrix_assign.h" #include "matrix_assign.h"
#ifdef DLIB_FOUND_BLAS #ifdef DLIB_FOUND_BLAS
#include "mkl_cblas.h" #include "cblas.h"
#endif #endif
namespace dlib namespace dlib
...@@ -32,9 +32,9 @@ namespace dlib ...@@ -32,9 +32,9 @@ namespace dlib
extern matrix<double,0,0,mm,column_major_layout> cm; // general matrix with column major order extern matrix<double,0,0,mm,column_major_layout> cm; // general matrix with column major order
extern matrix<double,1,0> rv; // general row vector extern matrix<double,1,0> rv; // general row vector
extern matrix<double,0,1> cv; // general column vector extern matrix<double,0,1> cv; // general column vector
extern const double s;
using namespace std;
#ifdef DLIB_FOUND_BLAS #ifdef DLIB_FOUND_BLAS
...@@ -59,6 +59,33 @@ namespace dlib ...@@ -59,6 +59,33 @@ namespace dlib
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(double, row_major_layout,rm + rm*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const double alpha = 1;
const double* A = &src.rhs.lhs(0,0);
const int lda = src.rhs.lhs.nc();
const double* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const double beta = 1;
double* C = &dest(0,0);
const int ldc = src.rhs.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(double, row_major_layout, trans(rm)*rm) DLIB_ADD_BLAS_BINDING(double, row_major_layout, trans(rm)*rm)
{ {
...@@ -81,8 +108,84 @@ namespace dlib ...@@ -81,8 +108,84 @@ namespace dlib
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(double, row_major_layout, rm + s*trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc());
const int K = static_cast<int>(src.rhs.m.lhs.nc());
const double alpha = src.rhs.s;
const double* A = &src.rhs.m.lhs.m(0,0);
const int lda = src.rhs.m.lhs.m.nc();
const double* B = &src.rhs.m.rhs(0,0);
const int ldb = src.rhs.m.rhs.nc();
const double beta = 1;
double* C = &dest(0,0);
const int ldc = dest.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(double, row_major_layout, s*trans(rm)*rm)
{
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.m.nr());
const int N = static_cast<int>(src.m.nc());
const int K = static_cast<int>(src.m.lhs.nc());
const double alpha = src.s;
const double* A = &src.m.lhs.m(0,0);
const int lda = src.m.lhs.m.nc();
const double* B = &src.m.rhs(0,0);
const int ldb = src.m.rhs.nc();
const double beta = 0;
double* C = &dest(0,0);
const int ldc = dest.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(double, row_major_layout, rm + trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const double alpha = 1;
const double* A = &src.rhs.lhs.m(0,0);
const int lda = src.rhs.lhs.m.nc();
const double* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const double beta = 1;
double* C = &dest(0,0);
const int ldc = src.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// ---------------------------------------------------------------------
// -------------------------- float overloads -------------------------- // -------------------------- float overloads --------------------------
// ---------------------------------------------------------------------
DLIB_ADD_BLAS_BINDING(float, row_major_layout, rm*rm) DLIB_ADD_BLAS_BINDING(float, row_major_layout, rm*rm)
{ {
...@@ -105,6 +208,33 @@ namespace dlib ...@@ -105,6 +208,33 @@ namespace dlib
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(float, row_major_layout,rm + rm*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const float alpha = 1;
const float* A = &src.rhs.lhs(0,0);
const int lda = src.rhs.lhs.nc();
const float* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const float beta = 1;
float* C = &dest(0,0);
const int ldc = src.rhs.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(float, row_major_layout, trans(rm)*rm) DLIB_ADD_BLAS_BINDING(float, row_major_layout, trans(rm)*rm)
{ {
...@@ -127,6 +257,80 @@ namespace dlib ...@@ -127,6 +257,80 @@ namespace dlib
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(float, row_major_layout, rm + s*trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc());
const int K = static_cast<int>(src.rhs.m.lhs.nc());
const float alpha = src.rhs.s;
const float* A = &src.rhs.m.lhs.m(0,0);
const int lda = src.rhs.m.lhs.m.nc();
const float* B = &src.rhs.m.rhs(0,0);
const int ldb = src.rhs.m.rhs.nc();
const float beta = 1;
float* C = &dest(0,0);
const int ldc = dest.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(float, row_major_layout, s*trans(rm)*rm)
{
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.m.nr());
const int N = static_cast<int>(src.m.nc());
const int K = static_cast<int>(src.m.lhs.nc());
const float alpha = src.s;
const float* A = &src.m.lhs.m(0,0);
const int lda = src.m.lhs.m.nc();
const float* B = &src.m.rhs(0,0);
const int ldb = src.m.rhs.nc();
const float beta = 0;
float* C = &dest(0,0);
const int ldc = dest.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(float, row_major_layout, rm + trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const float alpha = 1;
const float* A = &src.rhs.lhs.m(0,0);
const int lda = src.rhs.lhs.m.nc();
const float* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const float beta = 1;
float* C = &dest(0,0);
const int ldc = src.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
#endif // DLIB_FOUND_BLAS #endif // DLIB_FOUND_BLAS
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment