Commit d1c92fb8 authored by Davis King

Simplified the BLAS binding macro.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402767
parent 88dea115
...@@ -71,45 +71,51 @@ namespace dlib ...@@ -71,45 +71,51 @@ namespace dlib
/*! These two matrices are the same if they are either: /*! These two matrices are the same if they are either:
- both row vectors - both row vectors
- both column vectors - both column vectors
- both general matrices with the same kind of layout type - both general non-vector matrices
!*/ !*/
const static bool value = (NR1 == 1 && NR2 == 1) || const static bool value = (NR1 == 1 && NR2 == 1) ||
(NC1==1 && NC2==1) || (NC1==1 && NC2==1) ||
(NR1!=1 && NC1!=1 && NR2!=1 && NC2!=1 && is_same_type<L1,L2>::value); (NR1!=1 && NC1!=1 && NR2!=1 && NC2!=1);
}; };
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
// This template struct is used to tell us if two matrix expressions both contain the same // This template struct is used to tell us if two matrix expressions both contain the same
// sequence of operators, expressions, and work on matrices laid out in memory in compatible ways. // sequence of operators, expressions. It also only has a value of true if the T expression
template <typename T, typename U> // contains only matrices with the given layout.
template <typename T, typename U, typename layout>
struct same_exp struct same_exp
{ {
const static bool value = is_same_type<typename T::exp_type, typename U::exp_type>::value || const static bool value = (is_same_type<typename T::exp_type, typename U::exp_type>::value ||
same_matrix<typename T::exp_type, typename U::exp_type>::value;; same_matrix<typename T::exp_type, typename U::exp_type>::value) &&
is_same_type<typename T::layout_type,layout>::value;
}; };
template <typename Tlhs, typename Ulhs, typename Trhs, typename Urhs> template <typename Tlhs, typename Ulhs, typename Trhs, typename Urhs, typename layout>
struct same_exp<matrix_multiply_exp<Tlhs,Trhs>, matrix_multiply_exp<Ulhs,Urhs> > struct same_exp<matrix_multiply_exp<Tlhs,Trhs>, matrix_multiply_exp<Ulhs,Urhs>,layout >
{ const static bool value = same_exp<Tlhs,Ulhs>::value && same_exp<Trhs,Urhs>::value; }; { const static bool value = same_exp<Tlhs,Ulhs,layout>::value && same_exp<Trhs,Urhs,layout>::value; };
template <typename Tlhs, typename Ulhs, typename Trhs, typename Urhs> template <typename Tlhs, typename Ulhs, typename Trhs, typename Urhs, typename layout>
struct same_exp<matrix_add_exp<Tlhs,Trhs>, matrix_add_exp<Ulhs,Urhs> > struct same_exp<matrix_add_exp<Tlhs,Trhs>, matrix_add_exp<Ulhs,Urhs>, layout >
{ const static bool value = same_exp<Tlhs,Ulhs>::value && same_exp<Trhs,Urhs>::value; }; { const static bool value = same_exp<Tlhs,Ulhs,layout>::value && same_exp<Trhs,Urhs,layout>::value; };
template <typename Tlhs, typename Ulhs, typename Trhs, typename Urhs> template <typename Tlhs, typename Ulhs, typename Trhs, typename Urhs, typename layout>
struct same_exp<matrix_subtract_exp<Tlhs,Trhs>, matrix_subtract_exp<Ulhs,Urhs> > struct same_exp<matrix_subtract_exp<Tlhs,Trhs>, matrix_subtract_exp<Ulhs,Urhs>, layout >
{ const static bool value = same_exp<Tlhs,Ulhs>::value && same_exp<Trhs,Urhs>::value; }; { const static bool value = same_exp<Tlhs,Ulhs,layout>::value && same_exp<Trhs,Urhs,layout>::value; };
template <typename T, typename U, bool Tb, bool Ub> struct same_exp<matrix_mul_scal_exp<T,Tb>, matrix_mul_scal_exp<U,Ub> > template <typename T, typename U, bool Tb, bool Ub, typename layout>
{ const static bool value = same_exp<T,U>::value; }; struct same_exp<matrix_mul_scal_exp<T,Tb>, matrix_mul_scal_exp<U,Ub>, layout >
{ const static bool value = same_exp<T,U,layout>::value; };
template <typename T, typename U> struct same_exp<matrix_div_scal_exp<T>, matrix_div_scal_exp<U> > template <typename T, typename U, typename layout>
{ const static bool value = same_exp<T,U>::value; }; struct same_exp<matrix_div_scal_exp<T>, matrix_div_scal_exp<U>, layout >
{ const static bool value = same_exp<T,U,layout>::value; };
template <typename T, typename U, typename OP> struct same_exp<matrix_unary_exp<T,OP>, matrix_unary_exp<U,OP> > template <typename T, typename U, typename OP, typename layout>
{ const static bool value = same_exp<T,U>::value; }; struct same_exp<matrix_unary_exp<T,OP>, matrix_unary_exp<U,OP>, layout >
{ const static bool value = same_exp<T,U,layout>::value; };
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
...@@ -123,10 +129,10 @@ namespace dlib ...@@ -123,10 +129,10 @@ namespace dlib
}; };
// This is a helper that is used below to apply the same_exp template to matrix expressions. // This is a helper that is used below to apply the same_exp template to matrix expressions.
template <typename T, typename U> template <typename T, typename layout, typename U>
typename enable_if<same_exp<T,U>,yes_type>::type test(U); typename enable_if<same_exp<T,U,layout>,yes_type>::type test(U);
template <typename T, typename U> template <typename T, typename layout, typename U>
typename disable_if<same_exp<T,U>,no_type>::type test(U); typename disable_if<same_exp<T,U,layout>,no_type>::type test(U);
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
...@@ -197,18 +203,19 @@ namespace dlib ...@@ -197,18 +203,19 @@ namespace dlib
// This is a macro to help us add overloads for the matrix_assign_blas_helper template. // This is a macro to help us add overloads for the matrix_assign_blas_helper template.
// Using this macro it is easy to add overloads for arbitrary matrix expressions. // Using this macro it is easy to add overloads for arbitrary matrix expressions.
#define DLIB_ADD_BLAS_BINDING( dest_layout, src_expression) \ #define DLIB_ADD_BLAS_BINDING(src_expression) \
template <typename T> struct BOOST_JOIN(blas,__LINE__) \ template <typename T, typename L> struct BOOST_JOIN(blas,__LINE__) \
{ const static bool value = sizeof(yes_type) == sizeof(test<T>(src_expression)); }; \ { const static bool value = sizeof(yes_type) == sizeof(test<T,L>(src_expression)); }; \
template < typename T, long NR, long NC, typename MM, typename src_exp > \ template < typename T, long NR, long NC, typename MM, typename L, typename src_exp >\
struct matrix_assign_blas_helper<T,NR,NC,MM,dest_layout, src_exp, \ struct matrix_assign_blas_helper<T,NR,NC,MM,L, src_exp, \
typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp> >::type > { \ typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp,L> >::type > { \
static void assign ( \ static void assign ( \
matrix<T,NR,NC,MM,dest_layout>& dest, \ matrix<T,NR,NC,MM,L>& dest, \
const src_exp& src, \ const src_exp& src, \
typename src_exp::type alpha, \ typename src_exp::type alpha, \
bool add_to \ bool add_to \
) { ) { \
const bool is_row_major_order = is_same_type<L,row_major_layout>::value;
#define DLIB_END_BLAS_BINDING }}; #define DLIB_END_BLAS_BINDING }};
......
...@@ -11,7 +11,8 @@ ...@@ -11,7 +11,8 @@
#include "cblas.h" #include "cblas.h"
#include <iostream> //#include <iostream>
//using namespace std;
namespace dlib namespace dlib
{ {
...@@ -195,6 +196,29 @@ namespace dlib ...@@ -195,6 +196,29 @@ namespace dlib
return result; return result;
} }
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// Helpers for determining the data pointer, LDA, and incX arguments to BLAS functions.
// Leading dimension to hand to BLAS for a row major matrix: its number of columns.
template <typename T, long NR, long NC, typename MM>
int get_ld (const matrix<T,NR,NC,MM,row_major_layout>& m)
{
    // Explicit conversion to the int BLAS expects, matching how M, N and K are
    // converted with static_cast<int> in the bindings of this file.
    return static_cast<int>(m.nc());
}

// Leading dimension to hand to BLAS for a column major matrix: its number of rows.
template <typename T, long NR, long NC, typename MM>
int get_ld (const matrix<T,NR,NC,MM,column_major_layout>& m)
{
    return static_cast<int>(m.nr());
}
// --------
// Element increment (the BLAS incX/incY argument) for a matrix object.
// The argument is only used for overload selection; the result is always 1.
template <typename T, long NR, long NC, typename MM, typename L>
int get_inc (const matrix<T,NR,NC,MM,L>& /*unused*/)
{
    const int unit_increment = 1;
    return unit_increment;
}
// --------
// Read-only data pointer for BLAS: the address of element (0,0) of m.
template <typename T, long NR, long NC, typename MM, typename L>
const T* get_ptr (const matrix<T,NR,NC,MM,L>& m)
{
    const T& first_element = m(0,0);
    return &first_element;
}

// Writable data pointer for BLAS: the address of element (0,0) of m.
template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (matrix<T,NR,NC,MM,L>& m)
{
    T& first_element = m(0,0);
    return &first_element;
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -203,10 +227,8 @@ namespace dlib ...@@ -203,10 +227,8 @@ namespace dlib
// simply here so we can build matrix expressions with the DLIB_ADD_BLAS_BINDING macro. // simply here so we can build matrix expressions with the DLIB_ADD_BLAS_BINDING macro.
typedef memory_manager<char>::kernel_1a mm;
// Note that the fact that these are double matrices isn't important, it is just a placeholder in this case. // Note that the fact that these are double matrices isn't important, it is just a placeholder in this case.
extern matrix<double,0,0,mm,row_major_layout> rm; // general matrix with row major order extern matrix<double> m; // general matrix
extern matrix<double,0,0,mm,column_major_layout> cm; // general matrix with column major order
extern matrix<double,1,0> rv; // general row vector extern matrix<double,1,0> rv; // general row vector
extern matrix<double,0,1> cv; // general column vector extern matrix<double,0,1> cv; // general column vector
extern const double s; extern const double s;
...@@ -217,89 +239,89 @@ namespace dlib ...@@ -217,89 +239,89 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, rm*rm) DLIB_ADD_BLAS_BINDING(m*m)
{ {
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans; const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.nr()); const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc()); const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc()); const int K = static_cast<int>(src.lhs.nc());
const T* A = &src.lhs(0,0); const T* A = get_ptr(src.lhs);
const int lda = src.lhs.nc(); const int lda = get_ld(src.lhs);
const T* B = &src.rhs(0,0); const T* B = get_ptr(src.rhs);
const int ldb = src.rhs.nc(); const int ldb = get_ld(src.rhs);
const T beta = static_cast<T>(add_to?1:0); const T beta = static_cast<T>(add_to?1:0);
T* C = &dest(0,0); T* C = get_ptr(dest);
const int ldc = src.nc(); const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*rm) DLIB_ADD_BLAS_BINDING(trans(m)*m)
{ {
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans; const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.nr()); const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc()); const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc()); const int K = static_cast<int>(src.lhs.nc());
const T* A = &src.lhs.m(0,0); const T* A = get_ptr(src.lhs.m);
const int lda = src.lhs.m.nc(); const int lda = get_ld(src.lhs.m);
const T* B = &src.rhs(0,0); const T* B = get_ptr(src.rhs);
const int ldb = src.rhs.nc(); const int ldb = get_ld(src.rhs);
const T beta = static_cast<T>(add_to?1:0); const T beta = static_cast<T>(add_to?1:0);
T* C = &dest(0,0); T* C = get_ptr(dest);
const int ldc = src.nc(); const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, rm*trans(rm)) DLIB_ADD_BLAS_BINDING(m*trans(m))
{ {
//cout << "BLAS" << endl; //cout << "BLAS: m*trans(m)" << endl;
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasTrans; const CBLAS_TRANSPOSE TransB = CblasTrans;
const int M = static_cast<int>(src.nr()); const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc()); const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc()); const int K = static_cast<int>(src.lhs.nc());
const T* A = &src.lhs(0,0); const T* A = get_ptr(src.lhs);
const int lda = src.lhs.nc(); const int lda = get_ld(src.lhs);
const T* B = &src.rhs.m(0,0); const T* B = get_ptr(src.rhs.m);
const int ldb = src.rhs.m.nc(); const int ldb = get_ld(src.rhs.m);
const T beta = static_cast<T>(add_to?1:0); const T beta = static_cast<T>(add_to?1:0);
T* C = &dest(0,0); T* C = get_ptr(dest);
const int ldc = src.nc(); const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*trans(rm)) DLIB_ADD_BLAS_BINDING(trans(m)*trans(m))
{ {
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasTrans; const CBLAS_TRANSPOSE TransB = CblasTrans;
const int M = static_cast<int>(src.nr()); const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc()); const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc()); const int K = static_cast<int>(src.lhs.nc());
const T* A = &src.lhs.m(0,0); const T* A = get_ptr(src.lhs.m);
const int lda = src.lhs.m.nc(); const int lda = get_ld(src.lhs.m);
const T* B = &src.rhs.m(0,0); const T* B = get_ptr(src.rhs.m);
const int ldb = src.rhs.m.nc(); const int ldb = get_ld(src.rhs.m);
const T beta = static_cast<T>(add_to?1:0); const T beta = static_cast<T>(add_to?1:0);
T* C = &dest(0,0); T* C = get_ptr(dest);
const int ldc = src.nc(); const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
...@@ -310,88 +332,88 @@ namespace dlib ...@@ -310,88 +332,88 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, rm*cv) DLIB_ADD_BLAS_BINDING(m*cv)
{ {
//cout << "BLAS: rm*cv" << endl; //cout << "BLAS: m*cv" << endl;
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const int M = static_cast<int>(src.lhs.nr()); const int M = static_cast<int>(src.lhs.nr());
const int N = static_cast<int>(src.lhs.nc()); const int N = static_cast<int>(src.lhs.nc());
const T* A = &src.lhs(0,0); const T* A = get_ptr(src.lhs);
const int lda = src.lhs.nc(); const int lda = get_ld(src.lhs);
const T* X = &src.rhs(0,0); const T* X = get_ptr(src.rhs);
const int incX = 1; const int incX = get_inc(src.rhs);
const T beta = static_cast<T>(add_to?1:0); const T beta = static_cast<T>(add_to?1:0);
T* Y = &dest(0,0); T* Y = get_ptr(dest);
const int incY = 1; const int incY = get_inc(dest);
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, rv*rm) DLIB_ADD_BLAS_BINDING(rv*m)
{ {
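// Note that rv*rm is the transpose of trans(rm)*trans(rv) (same elements, so the same gemv call applies) // Note that rv*m is the transpose of trans(m)*trans(rv) (same elements, so the same gemv call applies)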
//cout << "BLAS: rv*rm" << endl; //cout << "BLAS: rv*m" << endl;
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
const int M = static_cast<int>(src.rhs.nr()); const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc()); const int N = static_cast<int>(src.rhs.nc());
const T* A = &src.rhs(0,0); const T* A = get_ptr(src.rhs);
const int lda = src.rhs.nc(); const int lda = get_ld(src.rhs);
const T* X = &src.lhs(0,0); const T* X = get_ptr(src.lhs);
const int incX = 1; const int incX = get_inc(src.lhs);
const T beta = static_cast<T>(add_to?1:0); const T beta = static_cast<T>(add_to?1:0);
T* Y = &dest(0,0); T* Y = get_ptr(dest);
const int incY = 1; const int incY = get_inc(dest);
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, trans(cv)*rm) DLIB_ADD_BLAS_BINDING(trans(cv)*m)
{ {
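// Note that trans(cv)*rm is the transpose of trans(rm)*cv (same elements, so the same gemv call applies) // Note that trans(cv)*m is the transpose of trans(m)*cv (same elements, so the same gemv call applies)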
//cout << "BLAS: trans(cv)*rm" << endl; //cout << "BLAS: trans(cv)*m" << endl;
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
const int M = static_cast<int>(src.rhs.nr()); const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc()); const int N = static_cast<int>(src.rhs.nc());
const T* A = &src.rhs(0,0); const T* A = get_ptr(src.rhs);
const int lda = src.rhs.nc(); const int lda = get_ld(src.rhs);
const T* X = &src.lhs.m(0,0); const T* X = get_ptr(src.lhs.m);
const int incX = 1; const int incX = get_inc(src.lhs.m);
const T beta = static_cast<T>(add_to?1:0); const T beta = static_cast<T>(add_to?1:0);
T* Y = &dest(0,0); T* Y = get_ptr(dest);
const int incY = 1; const int incY = get_inc(dest);
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, rm*trans(rv)) DLIB_ADD_BLAS_BINDING(m*trans(rv))
{ {
//cout << "BLAS: rm*trans(rv)" << endl; //cout << "BLAS: m*trans(rv)" << endl;
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const int M = static_cast<int>(src.lhs.nr()); const int M = static_cast<int>(src.lhs.nr());
const int N = static_cast<int>(src.lhs.nc()); const int N = static_cast<int>(src.lhs.nc());
const T* A = &src.lhs(0,0); const T* A = get_ptr(src.lhs);
const int lda = src.lhs.nc(); const int lda = get_ld(src.lhs);
const T* X = &src.rhs.m(0,0); const T* X = get_ptr(src.rhs.m);
const int incX = 1; const int incX = get_inc(src.rhs.m);
const T beta = static_cast<T>(add_to?1:0); const T beta = static_cast<T>(add_to?1:0);
T* Y = &dest(0,0); T* Y = get_ptr(dest);
const int incY = 1; const int incY = get_inc(dest);
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
...@@ -400,88 +422,88 @@ namespace dlib ...@@ -400,88 +422,88 @@ namespace dlib
// -------------------------------------- // --------------------------------------
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*cv) DLIB_ADD_BLAS_BINDING(trans(m)*cv)
{ {
//cout << "BLAS: trans(rm)*cv" << endl; //cout << "BLAS: trans(m)*cv" << endl;
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
const int M = static_cast<int>(src.lhs.m.nr()); const int M = static_cast<int>(src.lhs.m.nr());
const int N = static_cast<int>(src.lhs.m.nc()); const int N = static_cast<int>(src.lhs.m.nc());
const T* A = &src.lhs.m(0,0); const T* A = get_ptr(src.lhs.m);
const int lda = src.lhs.m.nc(); const int lda = get_ld(src.lhs.m);
const T* X = &src.rhs(0,0); const T* X = get_ptr(src.rhs);
const int incX = 1; const int incX = get_inc(src.rhs);
const T beta = static_cast<T>(add_to?1:0); const T beta = static_cast<T>(add_to?1:0);
T* Y = &dest(0,0); T* Y = get_ptr(dest);
const int incY = 1; const int incY = get_inc(dest);
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, rv*trans(rm)) DLIB_ADD_BLAS_BINDING(rv*trans(m))
{ {
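// Note that rv*trans(rm) is the transpose of rm*trans(rv) (same elements, so the same gemv call applies) // Note that rv*trans(m) is the transpose of m*trans(rv) (same elements, so the same gemv call applies)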
//cout << "BLAS: rv*trans(rm)" << endl; //cout << "BLAS: rv*trans(m)" << endl;
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const int M = static_cast<int>(src.rhs.m.nr()); const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc()); const int N = static_cast<int>(src.rhs.m.nc());
const T* A = &src.rhs.m(0,0); const T* A = get_ptr(src.rhs.m);
const int lda = src.rhs.m.nc(); const int lda = get_ld(src.rhs.m);
const T* X = &src.lhs(0,0); const T* X = get_ptr(src.lhs);
const int incX = 1; const int incX = get_inc(src.lhs);
const T beta = static_cast<T>(add_to?1:0); const T beta = static_cast<T>(add_to?1:0);
T* Y = &dest(0,0); T* Y = get_ptr(dest);
const int incY = 1; const int incY = get_inc(dest);
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, trans(cv)*trans(rm)) DLIB_ADD_BLAS_BINDING(trans(cv)*trans(m))
{ {
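// Note that trans(cv)*trans(rm) is the transpose of rm*cv (same elements, so the same gemv call applies) // Note that trans(cv)*trans(m) is the transpose of m*cv (same elements, so the same gemv call applies)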
//cout << "BLAS: trans(cv)*trans(rm)" << endl; //cout << "BLAS: trans(cv)*trans(m)" << endl;
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const int M = static_cast<int>(src.rhs.m.nr()); const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc()); const int N = static_cast<int>(src.rhs.m.nc());
const T* A = &src.rhs.m(0,0); const T* A = get_ptr(src.rhs.m);
const int lda = src.rhs.m.nc(); const int lda = get_ld(src.rhs.m);
const T* X = &src.lhs.m(0,0); const T* X = get_ptr(src.lhs.m);
const int incX = 1; const int incX = get_inc(src.lhs.m);
const T beta = static_cast<T>(add_to?1:0); const T beta = static_cast<T>(add_to?1:0);
T* Y = &dest(0,0); T* Y = get_ptr(dest);
const int incY = 1; const int incY = get_inc(dest);
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*trans(rv)) DLIB_ADD_BLAS_BINDING(trans(m)*trans(rv))
{ {
//cout << "BLAS: trans(rm)*trans(rv)" << endl; //cout << "BLAS: trans(m)*trans(rv)" << endl;
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
const int M = static_cast<int>(src.lhs.m.nr()); const int M = static_cast<int>(src.lhs.m.nr());
const int N = static_cast<int>(src.lhs.m.nc()); const int N = static_cast<int>(src.lhs.m.nc());
const T* A = &src.lhs.m(0,0); const T* A = get_ptr(src.lhs.m);
const int lda = src.lhs.m.nc(); const int lda = get_ld(src.lhs.m);
const T* X = &src.rhs.m(0,0); const T* X = get_ptr(src.rhs.m);
const int incX = 1; const int incX = get_inc(src.rhs.m);
const T beta = static_cast<T>(add_to?1:0); const T beta = static_cast<T>(add_to?1:0);
T* Y = &dest(0,0); T* Y = get_ptr(dest);
const int incY = 1; const int incY = get_inc(dest);
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment