Commit 51dcdfad authored by Davis King's avatar Davis King

Added the last few BLAS bindings and adjusted a few of the

constants related to matrix expression cost and when the BLAS
bindings can become active.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402778
parent 715ca0da
......@@ -28,7 +28,7 @@ namespace dlib
struct is_small_matrix { static const bool value = false; };
template < typename EXP >
struct is_small_matrix<EXP, typename enable_if_c<EXP::NR>=1 && EXP::NC>=1 &&
EXP::NR<=100 && EXP::NC<=100 && (EXP::cost < 70)>::type> { static const bool value = true; };
EXP::NR<=30 && EXP::NC<=30 && (EXP::cost < 70)>::type> { static const bool value = true; };
}
// ----------------------------------------------------------------------------------------
......
......@@ -413,6 +413,78 @@ namespace dlib
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
// --------------------------------------
// --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(conj(m))*m)
{
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasConjTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const T* A = get_ptr(src.lhs.m);
const int lda = get_ld(src.lhs.m);
const T* B = get_ptr(src.rhs);
const int ldb = get_ld(src.rhs);
const T beta = static_cast<T>(add_to?1:0);
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(m*trans(conj(m)))
{
//cout << "BLAS: m*trans(conj(m))" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasConjTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const T* A = get_ptr(src.lhs);
const int lda = get_ld(src.lhs);
const T* B = get_ptr(src.rhs.m);
const int ldb = get_ld(src.rhs.m);
const T beta = static_cast<T>(add_to?1:0);
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(conj(m))*trans(conj(m)))
{
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasConjTrans;
const CBLAS_TRANSPOSE TransB = CblasConjTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const T* A = get_ptr(src.lhs.m);
const int lda = get_ld(src.lhs.m);
const T* B = get_ptr(src.rhs.m);
const int ldb = get_ld(src.rhs.m);
const T beta = static_cast<T>(add_to?1:0);
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// GEMV overloads
......@@ -604,6 +676,98 @@ namespace dlib
} DLIB_END_BLAS_BINDING
// --------------------------------------
// --------------------------------------
// --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(conj(m))*cv)
{
//cout << "BLAS: trans(conj(m))*cv" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasConjTrans;
const int M = static_cast<int>(src.lhs.m.nr());
const int N = static_cast<int>(src.lhs.m.nc());
const T* A = get_ptr(src.lhs.m);
const int lda = get_ld(src.lhs.m);
const T* X = get_ptr(src.rhs);
const int incX = get_inc(src.rhs);
const T beta = static_cast<T>(add_to?1:0);
T* Y = get_ptr(dest);
const int incY = get_inc(dest);
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(rv*trans(conj(m)))
{
// Note that rv*trans(m) is the same as m*trans(rv)
//cout << "BLAS: rv*trans(conj(m))" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasConjTrans;
const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc());
const T* A = get_ptr(src.rhs.m);
const int lda = get_ld(src.rhs.m);
const T* X = get_ptr(src.lhs);
const int incX = get_inc(src.lhs);
const T beta = static_cast<T>(add_to?1:0);
T* Y = get_ptr(dest);
const int incY = get_inc(dest);
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(cv)*trans(conj(m)))
{
// Note that trans(cv)*trans(m) is the same as m*cv
//cout << "BLAS: trans(cv)*trans(m)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasConjTrans;
const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc());
const T* A = get_ptr(src.rhs.m);
const int lda = get_ld(src.rhs.m);
const T* X = get_ptr(src.lhs.m);
const int incX = get_inc(src.lhs.m);
const T beta = static_cast<T>(add_to?1:0);
T* Y = get_ptr(dest);
const int incY = get_inc(dest);
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(conj(m))*trans(rv))
{
//cout << "BLAS: trans(conj(m))*trans(rv)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasConjTrans;
const int M = static_cast<int>(src.lhs.m.nr());
const int N = static_cast<int>(src.lhs.m.nc());
const T* A = get_ptr(src.lhs.m);
const int lda = get_ld(src.lhs.m);
const T* X = get_ptr(src.rhs.m);
const int incX = get_inc(src.rhs.m);
const T beta = static_cast<T>(add_to?1:0);
T* Y = get_ptr(dest);
const int incY = get_inc(dest);
cblas_gemv(Order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
} DLIB_END_BLAS_BINDING
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......
......@@ -36,7 +36,7 @@ namespace dlib
// ----------------------------------------------------------------------------------------
DLIB_MATRIX_SIMPLE_STD_FUNCTION(abs,3)
DLIB_MATRIX_SIMPLE_STD_FUNCTION(abs,7)
DLIB_MATRIX_SIMPLE_STD_FUNCTION(sqrt,7)
DLIB_MATRIX_SIMPLE_STD_FUNCTION(log,7)
DLIB_MATRIX_SIMPLE_STD_FUNCTION(log10,7)
......@@ -44,8 +44,8 @@ DLIB_MATRIX_SIMPLE_STD_FUNCTION(exp,7)
DLIB_MATRIX_SIMPLE_STD_FUNCTION(conj,1)
DLIB_MATRIX_SIMPLE_STD_FUNCTION(ceil,2)
DLIB_MATRIX_SIMPLE_STD_FUNCTION(floor,20)
DLIB_MATRIX_SIMPLE_STD_FUNCTION(ceil,7)
DLIB_MATRIX_SIMPLE_STD_FUNCTION(floor,7)
DLIB_MATRIX_SIMPLE_STD_FUNCTION(sin,7)
DLIB_MATRIX_SIMPLE_STD_FUNCTION(cos,7)
......@@ -91,7 +91,7 @@ DLIB_MATRIX_SIMPLE_STD_FUNCTION(atan,7)
template <typename EXP>
struct op : has_nondestructive_aliasing, preserves_dimensions<EXP>
{
const static long cost = EXP::cost+1;
const static long cost = EXP::cost+7;
typedef typename EXP::type type;
template <typename M, typename T>
static type apply ( const M& m, const T& eps, long r, long c)
......@@ -138,7 +138,7 @@ DLIB_MATRIX_SIMPLE_STD_FUNCTION(atan,7)
template <typename EXP>
struct op : has_nondestructive_aliasing, preserves_dimensions<EXP>
{
const static long cost = EXP::cost+2;
const static long cost = EXP::cost+7;
typedef typename EXP::type type;
template <typename M>
static type apply ( const M& m, long r, long c)
......@@ -166,7 +166,7 @@ DLIB_MATRIX_SIMPLE_STD_FUNCTION(atan,7)
template <typename EXP>
struct op : has_nondestructive_aliasing, preserves_dimensions<EXP>
{
const static long cost = EXP::cost+1;
const static long cost = EXP::cost+6;
typedef typename EXP::type type;
template <typename M>
static type apply ( const M& m, long r, long c)
......@@ -194,7 +194,7 @@ DLIB_MATRIX_SIMPLE_STD_FUNCTION(atan,7)
template <typename EXP>
struct op : has_nondestructive_aliasing, preserves_dimensions<EXP>
{
const static long cost = EXP::cost+4;
const static long cost = EXP::cost+7;
typedef typename EXP::type type;
template <typename M, typename S>
static type apply ( const M& m, const S& s, long r, long c)
......@@ -227,7 +227,7 @@ DLIB_MATRIX_SIMPLE_STD_FUNCTION(atan,7)
template <typename EXP>
struct op : has_nondestructive_aliasing, preserves_dimensions<EXP>
{
const static long cost = EXP::cost+2;
const static long cost = EXP::cost+6;
typedef typename EXP::type type;
template <typename M>
static type apply ( const M& m, long r, long c)
......@@ -264,7 +264,7 @@ DLIB_MATRIX_SIMPLE_STD_FUNCTION(atan,7)
template <typename EXP>
struct op : has_nondestructive_aliasing, preserves_dimensions<EXP>
{
const static long cost = EXP::cost+1;
const static long cost = EXP::cost+5;
typedef typename EXP::type type;
template <typename M>
static type apply ( const M& m, const type& s, long r, long c)
......@@ -303,7 +303,7 @@ DLIB_MATRIX_SIMPLE_STD_FUNCTION(atan,7)
template <typename EXP, typename enabled = void>
struct op : has_nondestructive_aliasing, preserves_dimensions<EXP>
{
const static long cost = EXP::cost+3;
const static long cost = EXP::cost+7;
typedef typename EXP::type type;
template <typename M>
static type apply ( const M& m, long r, long c)
......@@ -388,7 +388,7 @@ DLIB_MATRIX_SIMPLE_STD_FUNCTION(atan,7)
template <typename EXP>
struct op : has_nondestructive_aliasing, preserves_dimensions<EXP>
{
const static long cost = EXP::cost+2;
const static long cost = EXP::cost+6;
typedef typename EXP::type::value_type type;
template <typename M>
static type apply ( const M& m, long r, long c)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment