Commit 4f8ff2c3 authored by Davis King

  - There were a few expressions that should have been bound to BLAS
    calls but weren't getting properly bound.  This has now been fixed.
  - There were a few cases where the code wouldn't compile when using
    complex numbers.  There was also a runtime bug that triggered when
    a rank 1 update was performed where one of the vectors was conjugated
    and two or more transposes were used in certain positions.  This bug
    caused the wrong output to be computed if the BLAS bindings were used.
    Both of these bugs have been fixed.
  - Added hooks for the blas binding counters that are used by the
    new blas_bindings regression tests.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403287
parent 899a3c63
......@@ -176,7 +176,7 @@ namespace dlib
// a temporary 1x1 matrix so that the expression will encounter
// all the overloads of matrix_assign() and have the chance to
// go through any applicable optimizations.
matrix<type,1,1> temp(ref());
matrix<type,1,1,mem_manager_type,layout_type> temp(ref());
return temp(0);
}
......@@ -266,8 +266,15 @@ namespace dlib
const static long NR = LHS::NR;
const static long NC = RHS::NC;
#ifdef DLIB_USE_BLAS
// if there are BLAS functions to be called then we want to make sure we
// always evaluate any complex expressions so that the BLAS bindings can happen.
const static bool lhs_is_costly = (LHS::cost > 1)&&(RHS::NC != 1 || LHS::cost >= 10000);
const static bool rhs_is_costly = (RHS::cost > 1)&&(LHS::NR != 1 || RHS::cost >= 10000);
#else
const static bool lhs_is_costly = (LHS::cost > 4)&&(RHS::NC != 1);
const static bool rhs_is_costly = (RHS::cost > 4)&&(LHS::NR != 1);
#endif
// Note that if we decide that one of the matrices is too costly we will evaluate it
// into a temporary. Doing this resets its cost back to 1.
......
......@@ -216,7 +216,7 @@ namespace dlib
{
if (add_to)
{
if (alpha == 1)
if (alpha == static_cast<typename EXP2::type>(1))
{
for (long c = 0; c < src.nc(); ++c)
{
......@@ -226,7 +226,7 @@ namespace dlib
}
}
}
else if (alpha == -1)
else if (alpha == static_cast<typename EXP2::type>(-1))
{
for (long c = 0; c < src.nc(); ++c)
{
......@@ -249,7 +249,7 @@ namespace dlib
}
else
{
if (alpha == 1)
if (alpha == static_cast<typename EXP2::type>(1))
{
for (long c = 0; c < src.nc(); ++c)
{
......
......@@ -20,6 +20,23 @@ namespace dlib
namespace blas_bindings
{
// Test hooks for the BLAS bindings.
//
// When DLIB_TEST_BLAS_BINDINGS is defined, each DLIB_TEST_BLAS_BINDING_* macro
// expands to an increment of the matching counter (one counter each for the
// gemm, gemv, ger and dot families).  The counter accessors are only declared
// here; their definitions are not part of this header section.
//
// When DLIB_TEST_BLAS_BINDINGS is not defined, the macros expand to nothing, so
// the hooks placed in the cblas_* wrappers add no code.
//
// Note that the counting variants already end in a semicolon.  The wrappers in
// this file invoke them as `DLIB_TEST_BLAS_BINDING_XXX;`, which therefore leaves
// one extra empty statement behind the increment.
#ifdef DLIB_TEST_BLAS_BINDINGS
// Each accessor returns a reference to an int counter that the macros below
// increment in place.
int& counter_gemm();
int& counter_gemv();
int& counter_ger();
int& counter_dot();
// Counting variants: bump the counter of the corresponding BLAS routine family.
#define DLIB_TEST_BLAS_BINDING_GEMM ++counter_gemm();
#define DLIB_TEST_BLAS_BINDING_GEMV ++counter_gemv();
#define DLIB_TEST_BLAS_BINDING_GER ++counter_ger();
#define DLIB_TEST_BLAS_BINDING_DOT ++counter_dot();
#else
// No-op variants: the hooks vanish when the binding tests are not enabled.
#define DLIB_TEST_BLAS_BINDING_GEMM
#define DLIB_TEST_BLAS_BINDING_GEMV
#define DLIB_TEST_BLAS_BINDING_GER
#define DLIB_TEST_BLAS_BINDING_DOT
#endif
extern "C"
{
// Here we declare the prototypes for the CBLAS calls used by the BLAS bindings below
......@@ -108,6 +125,7 @@ namespace dlib
const int lda, const float *B, const int ldb,
const float beta, float *C, const int ldc)
{
DLIB_TEST_BLAS_BINDING_GEMM;
cblas_sgemm( Order, TransA, TransB, M, N,
K, alpha, A, lda, B, ldb, beta, C, ldc);
}
......@@ -118,6 +136,7 @@ namespace dlib
const int lda, const double *B, const int ldb,
const double beta, double *C, const int ldc)
{
DLIB_TEST_BLAS_BINDING_GEMM;
cblas_dgemm( Order, TransA, TransB, M, N,
K, alpha, A, lda, B, ldb, beta, C, ldc);
}
......@@ -128,6 +147,7 @@ namespace dlib
const int lda, const std::complex<float> *B, const int ldb,
const std::complex<float>& beta, std::complex<float> *C, const int ldc)
{
DLIB_TEST_BLAS_BINDING_GEMM;
cblas_cgemm( Order, TransA, TransB, M, N,
K, &alpha, A, lda, B, ldb, &beta, C, ldc);
}
......@@ -138,6 +158,7 @@ namespace dlib
const int lda, const std::complex<double> *B, const int ldb,
const std::complex<double>& beta, std::complex<double> *C, const int ldc)
{
DLIB_TEST_BLAS_BINDING_GEMM;
cblas_zgemm( Order, TransA, TransB, M, N,
K, &alpha, A, lda, B, ldb, &beta, C, ldc);
}
......@@ -150,6 +171,7 @@ namespace dlib
const float *X, const int incX, const float beta,
float *Y, const int incY)
{
DLIB_TEST_BLAS_BINDING_GEMV;
cblas_sgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
}
......@@ -159,6 +181,7 @@ namespace dlib
const double *X, const int incX, const double beta,
double *Y, const int incY)
{
DLIB_TEST_BLAS_BINDING_GEMV;
cblas_dgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
}
......@@ -168,6 +191,7 @@ namespace dlib
const std::complex<float> *X, const int incX, const std::complex<float>& beta,
std::complex<float> *Y, const int incY)
{
DLIB_TEST_BLAS_BINDING_GEMV;
cblas_cgemv(order, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY);
}
......@@ -177,6 +201,7 @@ namespace dlib
const std::complex<double> *X, const int incX, const std::complex<double>& beta,
std::complex<double> *Y, const int incY)
{
DLIB_TEST_BLAS_BINDING_GEMV;
cblas_zgemv(order, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY);
}
......@@ -186,6 +211,7 @@ namespace dlib
const std::complex<float>& alpha, const std::complex<float> *X, const int incX,
const std::complex<float> *Y, const int incY, std::complex<float> *A, const int lda)
{
DLIB_TEST_BLAS_BINDING_GER;
cblas_cgeru (order, M, N, &alpha, X, incX, Y, incY, A, lda);
}
......@@ -193,6 +219,7 @@ namespace dlib
const std::complex<double>& alpha, const std::complex<double> *X, const int incX,
const std::complex<double> *Y, const int incY, std::complex<double> *A, const int lda)
{
DLIB_TEST_BLAS_BINDING_GER;
cblas_zgeru (order, M, N, &alpha, X, incX, Y, incY, A, lda);
}
......@@ -200,6 +227,7 @@ namespace dlib
const float alpha, const float *X, const int incX,
const float *Y, const int incY, float *A, const int lda)
{
DLIB_TEST_BLAS_BINDING_GER;
cblas_sger (order, M, N, alpha, X, incX, Y, incY, A, lda);
}
......@@ -207,6 +235,7 @@ namespace dlib
const double alpha, const double *X, const int incX,
const double *Y, const int incY, double *A, const int lda)
{
DLIB_TEST_BLAS_BINDING_GER;
cblas_dger (order, M, N, alpha, X, incX, Y, incY, A, lda);
}
......@@ -216,6 +245,7 @@ namespace dlib
const std::complex<float>& alpha, const std::complex<float> *X, const int incX,
const std::complex<float> *Y, const int incY, std::complex<float> *A, const int lda)
{
DLIB_TEST_BLAS_BINDING_GER;
cblas_cgerc (order, M, N, &alpha, X, incX, Y, incY, A, lda);
}
......@@ -223,6 +253,7 @@ namespace dlib
const std::complex<double>& alpha, const std::complex<double> *X, const int incX,
const std::complex<double> *Y, const int incY, std::complex<double> *A, const int lda)
{
DLIB_TEST_BLAS_BINDING_GER;
cblas_zgerc (order, M, N, &alpha, X, incX, Y, incY, A, lda);
}
......@@ -231,18 +262,21 @@ namespace dlib
// float overload of cblas_dot.
// Runs the DLIB_TEST_BLAS_BINDING_DOT hook and then forwards every argument,
// unchanged, to cblas_sdot, returning whatever that call produces.
inline float cblas_dot(
    const int N,
    const float *X, const int incX,
    const float *Y, const int incY
)
{
    DLIB_TEST_BLAS_BINDING_DOT;
    const float dot_value = cblas_sdot(N, X, incX, Y, incY);
    return dot_value;
}
// double overload of cblas_dot.
// Runs the DLIB_TEST_BLAS_BINDING_DOT hook and then forwards every argument,
// unchanged, to cblas_ddot, returning whatever that call produces.
inline double cblas_dot(
    const int N,
    const double *X, const int incX,
    const double *Y, const int incY
)
{
    DLIB_TEST_BLAS_BINDING_DOT;
    const double dot_value = cblas_ddot(N, X, incX, Y, incY);
    return dot_value;
}
inline std::complex<float> cblas_dot(const int N, const std::complex<float> *X, const int incX,
const std::complex<float> *Y, const int incY)
{
DLIB_TEST_BLAS_BINDING_DOT;
std::complex<float> result;
cblas_cdotu_sub(N, X, incX, Y, incY, &result);
return result;
......@@ -251,6 +285,7 @@ namespace dlib
inline std::complex<double> cblas_dot(const int N, const std::complex<double> *X, const int incX,
const std::complex<double> *Y, const int incY)
{
DLIB_TEST_BLAS_BINDING_DOT;
std::complex<double> result;
cblas_zdotu_sub(N, X, incX, Y, incY, &result);
return result;
......@@ -261,6 +296,7 @@ namespace dlib
inline std::complex<float> cblas_dotc(const int N, const std::complex<float> *X, const int incX,
const std::complex<float> *Y, const int incY)
{
DLIB_TEST_BLAS_BINDING_DOT;
std::complex<float> result;
cblas_cdotc_sub(N, X, incX, Y, incY, &result);
return result;
......@@ -269,6 +305,7 @@ namespace dlib
inline std::complex<double> cblas_dotc(const int N, const std::complex<double> *X, const int incX,
const std::complex<double> *Y, const int incY)
{
DLIB_TEST_BLAS_BINDING_DOT;
std::complex<double> result;
cblas_zdotc_sub(N, X, incX, Y, incY, &result);
return result;
......@@ -568,6 +605,34 @@ namespace dlib
// --------------------------------------
// Binds expressions of the form trans(conj(m))*trans(m) to cblas_gemm.
// The left operand is handed to BLAS with CblasConjTrans and the right operand
// with CblasTrans, so the (conjugate) transposes are applied by the GEMM call
// itself rather than being evaluated beforehand.
// NOTE(review): src, dest, alpha, add_to, transpose and T are not declared in
// this body; presumably they are supplied by the DLIB_ADD_BLAS_BINDING macro.
DLIB_ADD_BLAS_BINDING(trans(conj(m))*trans(m))
{
//cout << "BLAS GEMM: trans(conj(m))*trans(m)" << endl;
// Choose the CBLAS storage order that matches the destination's layout type.
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasConjTrans;
const CBLAS_TRANSPOSE TransB = CblasTrans;
// M and N are the row/column counts of the product expression; K is the
// column count of its left-hand operand (the GEMM inner dimension).
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
// Data pointer and leading dimension are taken from the object wrapped by each
// operand (the .m member), not from the operand expression itself.
const T* A = get_ptr(src.lhs.m);
const int lda = get_ld(src.lhs.m);
const T* B = get_ptr(src.rhs.m);
const int ldb = get_ld(src.rhs.m);
// beta is the GEMM scale factor applied to the existing contents of C:
// 1 keeps them (accumulate into dest), 0 discards them (overwrite dest).
const T beta = static_cast<T>(add_to?1:0);
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
// Only the non-transposed result goes through BLAS.  When a transposed result
// is requested the assignment is delegated to matrix_assign_default on
// trans(src) instead.
if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
matrix_assign_default(dest, trans(src), alpha, add_to);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(m*trans(conj(m)))
{
//cout << "BLAS GEMM: m*trans(conj(m))" << endl;
......@@ -595,6 +660,32 @@ namespace dlib
// --------------------------------------
// Binds expressions of the form trans(m)*trans(conj(m)) to cblas_gemm.
// The left operand is handed to BLAS with CblasTrans and the right operand with
// CblasConjTrans, so the (conjugate) transposes are applied by the GEMM call
// itself rather than being evaluated beforehand.
// NOTE(review): src, dest, alpha, add_to, transpose and T are not declared in
// this body; presumably they are supplied by the DLIB_ADD_BLAS_BINDING macro.
DLIB_ADD_BLAS_BINDING(trans(m)*trans(conj(m)))
{
//cout << "BLAS GEMM: trans(m)*trans(conj(m))" << endl;
// Choose the CBLAS storage order that matches the destination's layout type.
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasConjTrans;
// M and N are the row/column counts of the product expression; K is the
// column count of its left-hand operand (the GEMM inner dimension).
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
// Data pointer and leading dimension are taken from the object wrapped by each
// operand (the .m member), not from the operand expression itself.
const T* A = get_ptr(src.lhs.m);
const int lda = get_ld(src.lhs.m);
const T* B = get_ptr(src.rhs.m);
const int ldb = get_ld(src.rhs.m);
// beta is the GEMM scale factor applied to the existing contents of C:
// 1 keeps them (accumulate into dest), 0 discards them (overwrite dest).
const T beta = static_cast<T>(add_to?1:0);
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
// Only the non-transposed result goes through BLAS.  When a transposed result
// is requested the assignment is delegated to matrix_assign_default on
// trans(src) instead.
if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
matrix_assign_default(dest, trans(src), alpha, add_to);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(conj(m))*trans(conj(m)))
{
//cout << "BLAS GEMM: trans(conj(m))*trans(conj(m))" << endl;
......@@ -1017,6 +1108,7 @@ namespace dlib
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
/*
DLIB_ADD_BLAS_BINDING(cv*conj(rv))
{
//cout << "BLAS GERC: cv*conj(rv)" << endl;
......@@ -1040,6 +1132,7 @@ namespace dlib
else
cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
*/
// --------------------------------------
......@@ -1055,16 +1148,22 @@ namespace dlib
const T* Y = get_ptr(src.rhs.m);
const int incY = get_inc(src.rhs.m);
if (add_to == false)
zero_matrix(dest);
T* A = get_ptr(dest);
const int lda = get_ld(dest);
if (transpose == false)
{
T* A = get_ptr(dest);
const int lda = get_ld(dest);
if (add_to == false)
zero_matrix(dest);
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
}
else
cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
{
matrix_assign_default(dest,trans(src),alpha,add_to);
//cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
}
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -1081,20 +1180,27 @@ namespace dlib
const T* Y = get_ptr(src.rhs.m);
const int incY = get_inc(src.rhs.m);
if (add_to == false)
zero_matrix(dest);
T* A = get_ptr(dest);
const int lda = get_ld(dest);
if (transpose == false)
{
if (add_to == false)
zero_matrix(dest);
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
}
else
cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
{
matrix_assign_default(dest,trans(src),alpha,add_to);
//cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
}
} DLIB_END_BLAS_BINDING
// --------------------------------------
/*
DLIB_ADD_BLAS_BINDING(trans(rv)*conj(rv))
{
//cout << "BLAS GERC: trans(rv)*conj(rv)" << endl;
......@@ -1118,6 +1224,7 @@ namespace dlib
else
cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
*/
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment