Commit 13a920d5 authored by Davis King's avatar Davis King

Cleaned up the BLAS bindings more.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402762
parent dea6bc04
...@@ -201,14 +201,14 @@ namespace dlib ...@@ -201,14 +201,14 @@ namespace dlib
// This is a macro to help us add overloads for the matrix_assign_blas_helper template. // This is a macro to help us add overloads for the matrix_assign_blas_helper template.
// Using this macro it is easy to add overloads for arbitrary matrix expressions. // Using this macro it is easy to add overloads for arbitrary matrix expressions.
#define DLIB_ADD_BLAS_BINDING( dest_type, dest_layout, src_expression) \ #define DLIB_ADD_BLAS_BINDING( dest_layout, src_expression) \
template <typename T> struct BOOST_JOIN(blas,__LINE__) \ template <typename T> struct BOOST_JOIN(blas,__LINE__) \
{ const static bool value = sizeof(yes_type) == sizeof(test<T>(src_expression)); }; \ { const static bool value = sizeof(yes_type) == sizeof(test<T>(src_expression)); }; \
template < long NR, long NC, typename MM, typename src_exp > \ template < typename T, long NR, long NC, typename MM, typename src_exp > \
struct matrix_assign_blas_helper<dest_type,NR,NC,MM,dest_layout, src_exp, \ struct matrix_assign_blas_helper<T,NR,NC,MM,dest_layout, src_exp, \
typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp> >::type > { \ typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp> >::type > { \
static void assign ( \ static void assign ( \
matrix<dest_type,NR,NC,MM,dest_layout>& dest, \ matrix<T,NR,NC,MM,dest_layout>& dest, \
const src_exp& src \ const src_exp& src \
) { ) {
...@@ -224,7 +224,11 @@ namespace dlib ...@@ -224,7 +224,11 @@ namespace dlib
typename T, long NR, long NC, typename MM, typename L, typename T, long NR, long NC, typename MM, typename L,
typename src_exp typename src_exp
> >
inline void matrix_assign_big ( inline typename enable_if_c<(is_same_type<T,float>::value ||
is_same_type<T,double>::value ||
is_same_type<T,std::complex<float> >::value ||
is_same_type<T,std::complex<double> >::value)
>::type matrix_assign_big (
matrix<T,NR,NC,MM,L>& dest, matrix<T,NR,NC,MM,L>& dest,
const src_exp& src const src_exp& src
) )
......
...@@ -16,178 +16,202 @@ namespace dlib ...@@ -16,178 +16,202 @@ namespace dlib
namespace blas_bindings namespace blas_bindings
{ {
// ----------------------------------------------------------------------------------------
// Here we declare some matrix objects for use in the DLIB_ADD_BLAS_BINDING macro. These
// extern declarations don't actually correspond to any real matrix objects. They are
// simply here so we can build matrix expressions with the DLIB_ADD_BLAS_BINDING marco.
typedef memory_manager<char>::kernel_1a mm;
// Note that the fact that these are double matrices isn't important. The type
// that matters is the one that is the first argument of the DLIB_ADD_BLAS_BINDING.
// That type determines what the type of the elements of the matrices that we
// are dealing with is.
extern matrix<double,0,0,mm,row_major_layout> rm; // general matrix with row major order
extern matrix<double,0,0,mm,column_major_layout> cm; // general matrix with column major order
extern matrix<double,1,0> rv; // general row vector
extern matrix<double,0,1> cv; // general column vector
extern const double s;
#ifdef DLIB_FOUND_BLAS #ifdef DLIB_FOUND_BLAS
DLIB_ADD_BLAS_BINDING(double, row_major_layout, rm*rm) // ----------------------------------------------------------------------------------------
{ // ----------------------------------------------------------------------------------------
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const double alpha = 1;
const double* A = &src.lhs(0,0);
const int lda = src.lhs.nc();
const double* B = &src.rhs(0,0);
const int ldb = src.rhs.nc();
const double beta = 0; inline void cblas_gemm( const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
double* C = &dest(0,0); const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int ldc = src.nc(); const int K, const float alpha, const float *A,
const int lda, const float *B, const int ldb,
const float beta, float *C, const int ldc)
{
cblas_sgemm( Order, TransA, TransB, M, N,
K, alpha, A, lda, B, ldb, beta, C, ldc);
}
inline void cblas_gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const double alpha, const double *A,
const int lda, const double *B, const int ldb,
const double beta, double *C, const int ldc)
{
cblas_dgemm( Order, TransA, TransB, M, N,
K, alpha, A, lda, B, ldb, beta, C, ldc);
}
inline void cblas_gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const std::complex<float> *alpha, const std::complex<float> *A,
const int lda, const std::complex<float> *B, const int ldb,
const std::complex<float> *beta, std::complex<float> *C, const int ldc)
{
cblas_cgemm( Order, TransA, TransB, M, N,
K, alpha, A, lda, B, ldb, beta, C, ldc);
}
inline void cblas_gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const std::complex<double> *alpha, const std::complex<double> *A,
const int lda, const std::complex<double> *B, const int ldb,
const std::complex<double> *beta, std::complex<double> *C, const int ldc)
{
cblas_zgemm( Order, TransA, TransB, M, N,
K, alpha, A, lda, B, ldb, beta, C, ldc);
}
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); // ----------------------------------------------------------------------------------------
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(double, row_major_layout,rm + rm*rm) inline void cblas_gemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *X, const int incX, const float beta,
float *Y, const int incY)
{ {
if (&src.lhs != &dest) cblas_sgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
{ }
dest = src.lhs;
} inline void cblas_gemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *X, const int incX, const double beta,
double *Y, const int incY)
{
cblas_dgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
}
inline void cblas_gemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const std::complex<float> *alpha, const std::complex<float> *A, const int lda,
const std::complex<float> *X, const int incX, const std::complex<float> *beta,
std::complex<float> *Y, const int incY)
{
cblas_cgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
}
inline void cblas_gemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const std::complex<double> *alpha, const std::complex<double> *A, const int lda,
const std::complex<double> *X, const int incX, const std::complex<double> *beta,
std::complex<double> *Y, const int incY)
{
cblas_zgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
}
const CBLAS_ORDER Order = CblasRowMajor; // ----------------------------------------------------------------------------------------
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const double alpha = 1;
const double* A = &src.rhs.lhs(0,0);
const int lda = src.rhs.lhs.nc();
const double* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const double beta = 1; inline void cblas_ger(const enum CBLAS_ORDER order, const int M, const int N,
double* C = &dest(0,0); const std::complex<float> *alpha, const std::complex<float> *X, const int incX,
const int ldc = src.rhs.nc(); const std::complex<float> *Y, const int incY, std::complex<float> *A, const int lda)
{
cblas_cgeru (order, M, N, alpha, X, incX, Y, incY, A, lda);
}
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); inline void cblas_ger(const enum CBLAS_ORDER order, const int M, const int N,
} DLIB_END_BLAS_BINDING const std::complex<double> *alpha, const std::complex<double> *X, const int incX,
const std::complex<double> *Y, const int incY, std::complex<double> *A, const int lda)
{
cblas_zgeru (order, M, N, alpha, X, incX, Y, incY, A, lda);
}
// -------------------------------------- inline void cblas_ger(const enum CBLAS_ORDER order, const int M, const int N,
const float alpha, const float *X, const int incX,
const float *Y, const int incY, float *A, const int lda)
{
cblas_sger (order, M, N, alpha, X, incX, Y, incY, A, lda);
}
DLIB_ADD_BLAS_BINDING(double, row_major_layout, trans(rm)*rm) inline void cblas_ger(const enum CBLAS_ORDER order, const int M, const int N,
const double alpha, const double *X, const int incX,
const double *Y, const int incY, double *A, const int lda)
{ {
const CBLAS_ORDER Order = CblasRowMajor; cblas_dger (order, M, N, alpha, X, incX, Y, incY, A, lda);
const CBLAS_TRANSPOSE TransA = CblasTrans; }
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const double alpha = 1;
const double* A = &src.lhs.m(0,0);
const int lda = src.lhs.m.nc();
const double* B = &src.rhs(0,0);
const int ldb = src.rhs.nc();
const double beta = 0; // ----------------------------------------------------------------------------------------
double* C = &dest(0,0);
const int ldc = src.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); inline void cblas_gerc(const enum CBLAS_ORDER order, const int M, const int N,
} DLIB_END_BLAS_BINDING const std::complex<float> *alpha, const std::complex<float> *X, const int incX,
const std::complex<float> *Y, const int incY, std::complex<float> *A, const int lda)
{
cblas_cgerc (order, M, N, alpha, X, incX, Y, incY, A, lda);
}
DLIB_ADD_BLAS_BINDING(double, row_major_layout, rm + s*trans(rm)*rm) inline void cblas_gerc(const enum CBLAS_ORDER order, const int M, const int N,
const std::complex<double> *alpha, const std::complex<double> *X, const int incX,
const std::complex<double> *Y, const int incY, std::complex<double> *A, const int lda)
{ {
if (&src.lhs != &dest) cblas_zgerc (order, M, N, alpha, X, incX, Y, incY, A, lda);
{ }
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor; // ----------------------------------------------------------------------------------------
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc());
const int K = static_cast<int>(src.rhs.m.lhs.nc());
const double alpha = src.rhs.s;
const double* A = &src.rhs.m.lhs.m(0,0);
const int lda = src.rhs.m.lhs.m.nc();
const double* B = &src.rhs.m.rhs(0,0);
const int ldb = src.rhs.m.rhs.nc();
const double beta = 1; inline float cblas_dot(const int N, const float *X, const int incX,
double* C = &dest(0,0); const float *Y, const int incY)
const int ldc = dest.nc(); {
return cblas_sdot(N, X, incX, Y, incY);
}
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); inline double cblas_dot(const int N, const double *X, const int incX,
} DLIB_END_BLAS_BINDING const double *Y, const int incY)
{
return cblas_ddot(N, X, incX, Y, incY);
}
DLIB_ADD_BLAS_BINDING(double, row_major_layout, s*trans(rm)*rm) inline std::complex<float> cblas_dot(const int N, const std::complex<float> *X, const int incX,
const std::complex<float> *Y, const int incY)
{ {
const CBLAS_ORDER Order = CblasRowMajor; std::complex<float> result;
const CBLAS_TRANSPOSE TransA = CblasTrans; cblas_cdotu_sub(N, X, incX, Y, incY, &result);
const CBLAS_TRANSPOSE TransB = CblasNoTrans; return result;
const int M = static_cast<int>(src.m.nr()); }
const int N = static_cast<int>(src.m.nc());
const int K = static_cast<int>(src.m.lhs.nc());
const double alpha = src.s;
const double* A = &src.m.lhs.m(0,0);
const int lda = src.m.lhs.m.nc();
const double* B = &src.m.rhs(0,0);
const int ldb = src.m.rhs.nc();
const double beta = 0; inline std::complex<double> cblas_dot(const int N, const std::complex<double> *X, const int incX,
double* C = &dest(0,0); const std::complex<double> *Y, const int incY)
const int ldc = dest.nc(); {
std::complex<double> result;
cblas_zdotu_sub(N, X, incX, Y, incY, &result);
return result;
}
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); // ----------------------------------------------------------------------------------------
} DLIB_END_BLAS_BINDING
inline std::complex<float> cblas_dotc(const int N, const std::complex<float> *X, const int incX,
const std::complex<float> *Y, const int incY)
{
std::complex<float> result;
cblas_cdotc_sub(N, X, incX, Y, incY, &result);
return result;
}
DLIB_ADD_BLAS_BINDING(double, row_major_layout, rm + trans(rm)*rm) inline std::complex<double> cblas_dotc(const int N, const std::complex<double> *X, const int incX,
const std::complex<double> *Y, const int incY)
{ {
if (&src.lhs != &dest) std::complex<double> result;
{ cblas_zdotc_sub(N, X, incX, Y, incY, &result);
dest = src.lhs; return result;
} }
const CBLAS_ORDER Order = CblasRowMajor; // ----------------------------------------------------------------------------------------
const CBLAS_TRANSPOSE TransA = CblasTrans; // ----------------------------------------------------------------------------------------
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const double alpha = 1;
const double* A = &src.rhs.lhs.m(0,0);
const int lda = src.rhs.lhs.m.nc();
const double* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const double beta = 1; // Here we declare some matrix objects for use in the DLIB_ADD_BLAS_BINDING macro. These
double* C = &dest(0,0); // extern declarations don't actually correspond to any real matrix objects. They are
const int ldc = src.nc(); // simply here so we can build matrix expressions with the DLIB_ADD_BLAS_BINDING marco.
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
typedef memory_manager<char>::kernel_1a mm;
// Note that the fact that these are double matrices isn't important, it is just a placeholder in this case.
extern matrix<double,0,0,mm,row_major_layout> rm; // general matrix with row major order
extern matrix<double,0,0,mm,column_major_layout> cm; // general matrix with column major order
extern matrix<double,1,0> rv; // general row vector
extern matrix<double,0,1> cv; // general column vector
extern const double s;
// --------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// -------------------------- float overloads --------------------------
// ---------------------------------------------------------------------
DLIB_ADD_BLAS_BINDING(float, row_major_layout, rm*rm) DLIB_ADD_BLAS_BINDING(row_major_layout, rm*rm)
{ {
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
...@@ -195,20 +219,20 @@ namespace dlib ...@@ -195,20 +219,20 @@ namespace dlib
const int M = static_cast<int>(src.nr()); const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc()); const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc()); const int K = static_cast<int>(src.lhs.nc());
const float alpha = 1; const T alpha = 1;
const float* A = &src.lhs(0,0); const T* A = &src.lhs(0,0);
const int lda = src.lhs.nc(); const int lda = src.lhs.nc();
const float* B = &src.rhs(0,0); const T* B = &src.rhs(0,0);
const int ldb = src.rhs.nc(); const int ldb = src.rhs.nc();
const float beta = 0; const T beta = 0;
float* C = &dest(0,0); T* C = &dest(0,0);
const int ldc = src.nc(); const int ldc = src.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(float, row_major_layout,rm + rm*rm) DLIB_ADD_BLAS_BINDING(row_major_layout,rm + rm*rm)
{ {
if (&src.lhs != &dest) if (&src.lhs != &dest)
{ {
...@@ -221,22 +245,22 @@ namespace dlib ...@@ -221,22 +245,22 @@ namespace dlib
const int M = static_cast<int>(src.rhs.nr()); const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc()); const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc()); const int K = static_cast<int>(src.rhs.lhs.nc());
const float alpha = 1; const T alpha = 1;
const float* A = &src.rhs.lhs(0,0); const T* A = &src.rhs.lhs(0,0);
const int lda = src.rhs.lhs.nc(); const int lda = src.rhs.lhs.nc();
const float* B = &src.rhs.rhs(0,0); const T* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc(); const int ldb = src.rhs.rhs.nc();
const float beta = 1; const T beta = 1;
float* C = &dest(0,0); T* C = &dest(0,0);
const int ldc = src.rhs.nc(); const int ldc = src.rhs.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(float, row_major_layout, trans(rm)*rm) DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*rm)
{ {
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
...@@ -244,20 +268,20 @@ namespace dlib ...@@ -244,20 +268,20 @@ namespace dlib
const int M = static_cast<int>(src.nr()); const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc()); const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc()); const int K = static_cast<int>(src.lhs.nc());
const float alpha = 1; const T alpha = 1;
const float* A = &src.lhs.m(0,0); const T* A = &src.lhs.m(0,0);
const int lda = src.lhs.m.nc(); const int lda = src.lhs.m.nc();
const float* B = &src.rhs(0,0); const T* B = &src.rhs(0,0);
const int ldb = src.rhs.nc(); const int ldb = src.rhs.nc();
const float beta = 0; const T beta = 0;
float* C = &dest(0,0); T* C = &dest(0,0);
const int ldc = src.nc(); const int ldc = src.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(float, row_major_layout, rm + s*trans(rm)*rm) DLIB_ADD_BLAS_BINDING(row_major_layout, rm + s*trans(rm)*rm)
{ {
if (&src.lhs != &dest) if (&src.lhs != &dest)
{ {
...@@ -270,20 +294,20 @@ namespace dlib ...@@ -270,20 +294,20 @@ namespace dlib
const int M = static_cast<int>(src.rhs.m.nr()); const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc()); const int N = static_cast<int>(src.rhs.m.nc());
const int K = static_cast<int>(src.rhs.m.lhs.nc()); const int K = static_cast<int>(src.rhs.m.lhs.nc());
const float alpha = src.rhs.s; const T alpha = src.rhs.s;
const float* A = &src.rhs.m.lhs.m(0,0); const T* A = &src.rhs.m.lhs.m(0,0);
const int lda = src.rhs.m.lhs.m.nc(); const int lda = src.rhs.m.lhs.m.nc();
const float* B = &src.rhs.m.rhs(0,0); const T* B = &src.rhs.m.rhs(0,0);
const int ldb = src.rhs.m.rhs.nc(); const int ldb = src.rhs.m.rhs.nc();
const float beta = 1; const T beta = 1;
float* C = &dest(0,0); T* C = &dest(0,0);
const int ldc = dest.nc(); const int ldc = dest.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(float, row_major_layout, s*trans(rm)*rm) DLIB_ADD_BLAS_BINDING(row_major_layout, s*trans(rm)*rm)
{ {
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
...@@ -291,21 +315,21 @@ namespace dlib ...@@ -291,21 +315,21 @@ namespace dlib
const int M = static_cast<int>(src.m.nr()); const int M = static_cast<int>(src.m.nr());
const int N = static_cast<int>(src.m.nc()); const int N = static_cast<int>(src.m.nc());
const int K = static_cast<int>(src.m.lhs.nc()); const int K = static_cast<int>(src.m.lhs.nc());
const float alpha = src.s; const T alpha = src.s;
const float* A = &src.m.lhs.m(0,0); const T* A = &src.m.lhs.m(0,0);
const int lda = src.m.lhs.m.nc(); const int lda = src.m.lhs.m.nc();
const float* B = &src.m.rhs(0,0); const T* B = &src.m.rhs(0,0);
const int ldb = src.m.rhs.nc(); const int ldb = src.m.rhs.nc();
const float beta = 0; const T beta = 0;
float* C = &dest(0,0); T* C = &dest(0,0);
const int ldc = dest.nc(); const int ldc = dest.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(float, row_major_layout, rm + trans(rm)*rm) DLIB_ADD_BLAS_BINDING(row_major_layout, rm + trans(rm)*rm)
{ {
if (&src.lhs != &dest) if (&src.lhs != &dest)
{ {
...@@ -318,17 +342,17 @@ namespace dlib ...@@ -318,17 +342,17 @@ namespace dlib
const int M = static_cast<int>(src.rhs.nr()); const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc()); const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc()); const int K = static_cast<int>(src.rhs.lhs.nc());
const float alpha = 1; const T alpha = 1;
const float* A = &src.rhs.lhs.m(0,0); const T* A = &src.rhs.lhs.m(0,0);
const int lda = src.rhs.lhs.m.nc(); const int lda = src.rhs.lhs.m.nc();
const float* B = &src.rhs.rhs(0,0); const T* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc(); const int ldb = src.rhs.rhs.nc();
const float beta = 1; const T beta = 1;
float* C = &dest(0,0); T* C = &dest(0,0);
const int ldc = src.nc(); const int ldc = src.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment