Commit 88dea115 authored by Davis King's avatar Davis King

Added GEMV BLAS bindings.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402766
parent 7a59836f
...@@ -8,7 +8,10 @@ ...@@ -8,7 +8,10 @@
#include "matrix/matrix_subexp.h" #include "matrix/matrix_subexp.h"
#include "matrix/matrix_math_functions.h" #include "matrix/matrix_math_functions.h"
#include "matrix/matrix_assign.h" #include "matrix/matrix_assign.h"
#ifdef DLIB_USE_BLAS
#include "matrix/matrix_blas_bindings.h" #include "matrix/matrix_blas_bindings.h"
#endif
#endif // DLIB_MATRIx_HEADER #endif // DLIB_MATRIx_HEADER
......
...@@ -3,11 +3,13 @@ ...@@ -3,11 +3,13 @@
#ifndef DLIB_MATRIx_BLAS_BINDINGS_ #ifndef DLIB_MATRIx_BLAS_BINDINGS_
#define DLIB_MATRIx_BLAS_BINDINGS_ #define DLIB_MATRIx_BLAS_BINDINGS_
#ifndef DLIB_USE_BLAS
#error "DLIB_USE_BLAS should be defined if you want to use the BLAS bindings"
#endif
#include "matrix_assign.h" #include "matrix_assign.h"
#ifdef DLIB_USE_BLAS
#include "cblas.h" #include "cblas.h"
#endif
#include <iostream> #include <iostream>
...@@ -18,8 +20,6 @@ namespace dlib ...@@ -18,8 +20,6 @@ namespace dlib
namespace blas_bindings namespace blas_bindings
{ {
#ifdef DLIB_USE_BLAS
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -211,6 +211,10 @@ namespace dlib ...@@ -211,6 +211,10 @@ namespace dlib
extern matrix<double,0,1> cv; // general column vector extern matrix<double,0,1> cv; // general column vector
extern const double s; extern const double s;
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// GEMM overloads
// ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, rm*rm) DLIB_ADD_BLAS_BINDING(row_major_layout, rm*rm)
...@@ -255,7 +259,265 @@ namespace dlib ...@@ -255,7 +259,265 @@ namespace dlib
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
#endif // DLIB_USE_BLAS // --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, rm*trans(rm))
{
//cout << "BLAS" << endl;
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const T* A = &src.lhs(0,0);
const int lda = src.lhs.nc();
const T* B = &src.rhs.m(0,0);
const int ldb = src.rhs.m.nc();
const T beta = static_cast<T>(add_to?1:0);
T* C = &dest(0,0);
const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*trans(rm))
{
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const T* A = &src.lhs.m(0,0);
const int lda = src.lhs.m.nc();
const T* B = &src.rhs.m(0,0);
const int ldb = src.rhs.m.nc();
const T beta = static_cast<T>(add_to?1:0);
T* C = &dest(0,0);
const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// GEMV overloads
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, rm*cv)
{
//cout << "BLAS: rm*cv" << endl;
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const int M = static_cast<int>(src.lhs.nr());
const int N = static_cast<int>(src.lhs.nc());
const T* A = &src.lhs(0,0);
const int lda = src.lhs.nc();
const T* X = &src.rhs(0,0);
const int incX = 1;
const T beta = static_cast<T>(add_to?1:0);
T* Y = &dest(0,0);
const int incY = 1;
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, rv*rm)
{
// Note that rv*rm is the same as trans(rm)*trans(rv)
//cout << "BLAS: rv*rm" << endl;
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const T* A = &src.rhs(0,0);
const int lda = src.rhs.nc();
const T* X = &src.lhs(0,0);
const int incX = 1;
const T beta = static_cast<T>(add_to?1:0);
T* Y = &dest(0,0);
const int incY = 1;
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, trans(cv)*rm)
{
// Note that trans(cv)*rm is the same as trans(rm)*cv
//cout << "BLAS: trans(cv)*rm" << endl;
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const T* A = &src.rhs(0,0);
const int lda = src.rhs.nc();
const T* X = &src.lhs.m(0,0);
const int incX = 1;
const T beta = static_cast<T>(add_to?1:0);
T* Y = &dest(0,0);
const int incY = 1;
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, rm*trans(rv))
{
//cout << "BLAS: rm*trans(rv)" << endl;
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const int M = static_cast<int>(src.lhs.nr());
const int N = static_cast<int>(src.lhs.nc());
const T* A = &src.lhs(0,0);
const int lda = src.lhs.nc();
const T* X = &src.rhs.m(0,0);
const int incX = 1;
const T beta = static_cast<T>(add_to?1:0);
T* Y = &dest(0,0);
const int incY = 1;
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING
// --------------------------------------
// --------------------------------------
// --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*cv)
{
//cout << "BLAS: trans(rm)*cv" << endl;
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const int M = static_cast<int>(src.lhs.m.nr());
const int N = static_cast<int>(src.lhs.m.nc());
const T* A = &src.lhs.m(0,0);
const int lda = src.lhs.m.nc();
const T* X = &src.rhs(0,0);
const int incX = 1;
const T beta = static_cast<T>(add_to?1:0);
T* Y = &dest(0,0);
const int incY = 1;
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, rv*trans(rm))
{
// Note that rv*trans(rm) is the same as rm*trans(rv)
//cout << "BLAS: rv*trans(rm)" << endl;
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc());
const T* A = &src.rhs.m(0,0);
const int lda = src.rhs.m.nc();
const T* X = &src.lhs(0,0);
const int incX = 1;
const T beta = static_cast<T>(add_to?1:0);
T* Y = &dest(0,0);
const int incY = 1;
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, trans(cv)*trans(rm))
{
// Note that trans(cv)*trans(rm) is the same as rm*cv
//cout << "BLAS: trans(cv)*trans(rm)" << endl;
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc());
const T* A = &src.rhs.m(0,0);
const int lda = src.rhs.m.nc();
const T* X = &src.lhs.m(0,0);
const int incX = 1;
const T beta = static_cast<T>(add_to?1:0);
T* Y = &dest(0,0);
const int incY = 1;
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*trans(rv))
{
//cout << "BLAS: trans(rm)*trans(rv)" << endl;
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const int M = static_cast<int>(src.lhs.m.nr());
const int N = static_cast<int>(src.lhs.m.nc());
const T* A = &src.lhs.m(0,0);
const int lda = src.lhs.m.nc();
const T* X = &src.rhs.m(0,0);
const int incX = 1;
const T beta = static_cast<T>(add_to?1:0);
T* Y = &dest(0,0);
const int incY = 1;
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING
// --------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// GER overloads
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// GERC overloads
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// DOT overloads
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// DOTC overloads
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment