Commit 5acf84a6 authored by Davis King

Fixed the remaining known issues with the BLAS bindings and added a bunch of related tests to the regression test suite.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402779
parent 51dcdfad
...@@ -47,22 +47,22 @@ namespace dlib ...@@ -47,22 +47,22 @@ namespace dlib
inline void cblas_gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, inline void cblas_gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const std::complex<float> *alpha, const std::complex<float> *A, const int K, const std::complex<float>& alpha, const std::complex<float> *A,
const int lda, const std::complex<float> *B, const int ldb, const int lda, const std::complex<float> *B, const int ldb,
const std::complex<float> *beta, std::complex<float> *C, const int ldc) const std::complex<float>& beta, std::complex<float> *C, const int ldc)
{ {
cblas_cgemm( Order, TransA, TransB, M, N, cblas_cgemm( Order, TransA, TransB, M, N,
K, alpha, A, lda, B, ldb, beta, C, ldc); K, &alpha, A, lda, B, ldb, &beta, C, ldc);
} }
inline void cblas_gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, inline void cblas_gemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const std::complex<double> *alpha, const std::complex<double> *A, const int K, const std::complex<double>& alpha, const std::complex<double> *A,
const int lda, const std::complex<double> *B, const int ldb, const int lda, const std::complex<double> *B, const int ldb,
const std::complex<double> *beta, std::complex<double> *C, const int ldc) const std::complex<double>& beta, std::complex<double> *C, const int ldc)
{ {
cblas_zgemm( Order, TransA, TransB, M, N, cblas_zgemm( Order, TransA, TransB, M, N,
K, alpha, A, lda, B, ldb, beta, C, ldc); K, &alpha, A, lda, B, ldb, &beta, C, ldc);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -87,20 +87,20 @@ namespace dlib ...@@ -87,20 +87,20 @@ namespace dlib
inline void cblas_gemv(const enum CBLAS_ORDER order, inline void cblas_gemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N, const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const std::complex<float> *alpha, const std::complex<float> *A, const int lda, const std::complex<float>& alpha, const std::complex<float> *A, const int lda,
const std::complex<float> *X, const int incX, const std::complex<float> *beta, const std::complex<float> *X, const int incX, const std::complex<float>& beta,
std::complex<float> *Y, const int incY) std::complex<float> *Y, const int incY)
{ {
cblas_cgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); cblas_cgemv(order, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY);
} }
inline void cblas_gemv(const enum CBLAS_ORDER order, inline void cblas_gemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N, const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const std::complex<double> *alpha, const std::complex<double> *A, const int lda, const std::complex<double>& alpha, const std::complex<double> *A, const int lda,
const std::complex<double> *X, const int incX, const std::complex<double> *beta, const std::complex<double> *X, const int incX, const std::complex<double>& beta,
std::complex<double> *Y, const int incY) std::complex<double> *Y, const int incY)
{ {
cblas_zgemv(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY); cblas_zgemv(order, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -323,7 +323,7 @@ namespace dlib ...@@ -323,7 +323,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(m*m) DLIB_ADD_BLAS_BINDING(m*m)
{ {
//cout << "BLAS: m*m" << endl; //cout << "BLAS GEMM: m*m" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
...@@ -347,6 +347,7 @@ namespace dlib ...@@ -347,6 +347,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(trans(m)*m) DLIB_ADD_BLAS_BINDING(trans(m)*m)
{ {
//cout << "BLAS GEMM: trans(m)*m" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
...@@ -370,7 +371,7 @@ namespace dlib ...@@ -370,7 +371,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(m*trans(m)) DLIB_ADD_BLAS_BINDING(m*trans(m))
{ {
//cout << "BLAS: m*trans(m)" << endl; //cout << "BLAS GEMM: m*trans(m)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
...@@ -394,6 +395,7 @@ namespace dlib ...@@ -394,6 +395,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(trans(m)*trans(m)) DLIB_ADD_BLAS_BINDING(trans(m)*trans(m))
{ {
//cout << "BLAS GEMM: trans(m)*trans(m)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
...@@ -419,6 +421,7 @@ namespace dlib ...@@ -419,6 +421,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(trans(conj(m))*m) DLIB_ADD_BLAS_BINDING(trans(conj(m))*m)
{ {
//cout << "BLAS GEMM: trans(conj(m))*m" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasConjTrans; const CBLAS_TRANSPOSE TransA = CblasConjTrans;
...@@ -442,7 +445,7 @@ namespace dlib ...@@ -442,7 +445,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(m*trans(conj(m))) DLIB_ADD_BLAS_BINDING(m*trans(conj(m)))
{ {
//cout << "BLAS: m*trans(conj(m))" << endl; //cout << "BLAS GEMM: m*trans(conj(m))" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
...@@ -466,6 +469,7 @@ namespace dlib ...@@ -466,6 +469,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(trans(conj(m))*trans(conj(m))) DLIB_ADD_BLAS_BINDING(trans(conj(m))*trans(conj(m)))
{ {
//cout << "BLAS GEMM: trans(conj(m))*trans(conj(m))" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasConjTrans; const CBLAS_TRANSPOSE TransA = CblasConjTrans;
...@@ -493,7 +497,7 @@ namespace dlib ...@@ -493,7 +497,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(m*cv) DLIB_ADD_BLAS_BINDING(m*cv)
{ {
//cout << "BLAS: m*cv" << endl; //cout << "BLAS GEMV: m*cv" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
...@@ -517,7 +521,7 @@ namespace dlib ...@@ -517,7 +521,7 @@ namespace dlib
{ {
// Note that rv*m is the same as trans(m)*trans(rv) // Note that rv*m is the same as trans(m)*trans(rv)
//cout << "BLAS: rv*m" << endl; //cout << "BLAS GEMV: rv*m" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
...@@ -541,7 +545,7 @@ namespace dlib ...@@ -541,7 +545,7 @@ namespace dlib
{ {
// Note that trans(cv)*m is the same as trans(m)*cv // Note that trans(cv)*m is the same as trans(m)*cv
//cout << "BLAS: trans(cv)*m" << endl; //cout << "BLAS GEMV: trans(cv)*m" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
...@@ -563,7 +567,7 @@ namespace dlib ...@@ -563,7 +567,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(m*trans(rv)) DLIB_ADD_BLAS_BINDING(m*trans(rv))
{ {
//cout << "BLAS: m*trans(rv)" << endl; //cout << "BLAS GEMV: m*trans(rv)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
...@@ -587,7 +591,7 @@ namespace dlib ...@@ -587,7 +591,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(trans(m)*cv) DLIB_ADD_BLAS_BINDING(trans(m)*cv)
{ {
//cout << "BLAS: trans(m)*cv" << endl; //cout << "BLAS GEMV: trans(m)*cv" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
...@@ -611,7 +615,7 @@ namespace dlib ...@@ -611,7 +615,7 @@ namespace dlib
{ {
// Note that rv*trans(m) is the same as m*trans(rv) // Note that rv*trans(m) is the same as m*trans(rv)
//cout << "BLAS: rv*trans(m)" << endl; //cout << "BLAS GEMV: rv*trans(m)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
...@@ -635,7 +639,7 @@ namespace dlib ...@@ -635,7 +639,7 @@ namespace dlib
{ {
// Note that trans(cv)*trans(m) is the same as m*cv // Note that trans(cv)*trans(m) is the same as m*cv
//cout << "BLAS: trans(cv)*trans(m)" << endl; //cout << "BLAS GEMV: trans(cv)*trans(m)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
...@@ -657,7 +661,7 @@ namespace dlib ...@@ -657,7 +661,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(trans(m)*trans(rv)) DLIB_ADD_BLAS_BINDING(trans(m)*trans(rv))
{ {
//cout << "BLAS: trans(m)*trans(rv)" << endl; //cout << "BLAS GEMV: trans(m)*trans(rv)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
...@@ -679,18 +683,19 @@ namespace dlib ...@@ -679,18 +683,19 @@ namespace dlib
// -------------------------------------- // --------------------------------------
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(conj(m))*cv) DLIB_ADD_BLAS_BINDING(trans(cv)*conj(m))
{ {
//cout << "BLAS: trans(conj(m))*cv" << endl; // Note that trans(cv)*conj(m) == conj(trans(m))*cv
//cout << "BLAS GEMV: trans(cv)*conj(m)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasConjTrans; const CBLAS_TRANSPOSE TransA = CblasConjTrans;
const int M = static_cast<int>(src.lhs.m.nr()); const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.lhs.m.nc()); const int N = static_cast<int>(src.rhs.m.nc());
const T* A = get_ptr(src.lhs.m); const T* A = get_ptr(src.rhs.m);
const int lda = get_ld(src.lhs.m); const int lda = get_ld(src.rhs.m);
const T* X = get_ptr(src.rhs); const T* X = get_ptr(src.lhs.m);
const int incX = get_inc(src.rhs); const int incX = get_inc(src.lhs.m);
const T beta = static_cast<T>(add_to?1:0); const T beta = static_cast<T>(add_to?1:0);
T* Y = get_ptr(dest); T* Y = get_ptr(dest);
...@@ -701,11 +706,10 @@ namespace dlib ...@@ -701,11 +706,10 @@ namespace dlib
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(rv*trans(conj(m))) DLIB_ADD_BLAS_BINDING(rv*conj(m))
{ {
// Note that rv*trans(m) is the same as m*trans(rv) // Note that rv*conj(m) == conj(trans(m))*cv
//cout << "BLAS GEMV: rv*conj(m)" << endl;
//cout << "BLAS: rv*trans(conj(m))" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasConjTrans; const CBLAS_TRANSPOSE TransA = CblasConjTrans;
...@@ -725,20 +729,18 @@ namespace dlib ...@@ -725,20 +729,18 @@ namespace dlib
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(cv)*trans(conj(m))) DLIB_ADD_BLAS_BINDING(trans(conj(m))*cv)
{ {
// Note that trans(cv)*trans(m) is the same as m*cv //cout << "BLAS GEMV: trans(conj(m))*cv" << endl;
//cout << "BLAS: trans(cv)*trans(m)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasConjTrans; const CBLAS_TRANSPOSE TransA = CblasConjTrans;
const int M = static_cast<int>(src.rhs.m.nr()); const int M = static_cast<int>(src.lhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc()); const int N = static_cast<int>(src.lhs.m.nc());
const T* A = get_ptr(src.rhs.m); const T* A = get_ptr(src.lhs.m);
const int lda = get_ld(src.rhs.m); const int lda = get_ld(src.lhs.m);
const T* X = get_ptr(src.lhs.m); const T* X = get_ptr(src.rhs);
const int incX = get_inc(src.lhs.m); const int incX = get_inc(src.rhs);
const T beta = static_cast<T>(add_to?1:0); const T beta = static_cast<T>(add_to?1:0);
T* Y = get_ptr(dest); T* Y = get_ptr(dest);
...@@ -751,7 +753,7 @@ namespace dlib ...@@ -751,7 +753,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(trans(conj(m))*trans(rv)) DLIB_ADD_BLAS_BINDING(trans(conj(m))*trans(rv))
{ {
//cout << "BLAS: trans(conj(m))*trans(rv)" << endl; //cout << "BLAS GEMV: trans(conj(m))*trans(rv)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasConjTrans; const CBLAS_TRANSPOSE TransA = CblasConjTrans;
......
...@@ -25,6 +25,348 @@ namespace ...@@ -25,6 +25,348 @@ namespace
logger dlog("test.matrix3"); logger dlog("test.matrix3");
// Multiplier applied to std::numeric_limits<...>::epsilon() to form the absolute
// error tolerance used by check_equal(), c_check_equal() and the scalar checks
// at the end of test_blas().
const double eps_mul = 200000;
// Asserts that a and b have the same number of rows and columns and that the
// absolute difference of every pair of corresponding elements is below
// std::numeric_limits<T::type>::epsilon()*eps_mul.
//
// Elements are visited in row-major order, so the first mismatch reported is
// the one with the smallest (row, column) position.
template <typename T, typename U>
void check_equal (
    const T& a,
    const U& b
)
{
    typedef typename T::type type;

    DLIB_CASSERT(a.nr() == b.nr(),"");
    DLIB_CASSERT(a.nc() == b.nc(),"");

    for (long row = 0; row < a.nr(); ++row)
    {
        for (long col = 0; col < a.nc(); ++col)
        {
            // The asserted expression is kept verbatim so that any failure
            // report produced by DLIB_CASSERT reads the same as before.
            type error = std::abs(a(row,col) - b(row,col));
            DLIB_CASSERT(error < std::numeric_limits<type>::epsilon()*eps_mul, "error: " << error <<
                         " eps: " << std::numeric_limits<type>::epsilon()*eps_mul);
        }
    }
}
// Variant of check_equal() for element types that expose a nested value_type
// (used in this file with matrices of complex<type>): the magnitude of the
// element-wise difference is held in T::type::value_type and must be below
// std::numeric_limits<T::type::value_type>::epsilon()*eps_mul.
//
// Also asserts that a and b have matching dimensions.
template <typename T, typename U>
void c_check_equal (
    const T& a,
    const U& b
)
{
    typedef typename T::type type;

    DLIB_CASSERT(a.nr() == b.nr(),"");
    DLIB_CASSERT(a.nc() == b.nc(),"");

    for (long row = 0; row < a.nr(); ++row)
    {
        for (long col = 0; col < a.nc(); ++col)
        {
            // The asserted expression is kept verbatim so that any failure
            // report produced by DLIB_CASSERT reads the same as before.
            typename type::value_type error = std::abs(a(row,col) - b(row,col));
            DLIB_CASSERT(error < std::numeric_limits<typename type::value_type>::epsilon()*eps_mul, "error: " << error <<
                         " eps: " << std::numeric_limits<typename type::value_type>::epsilon()*eps_mul);
        }
    }
}
// Copies b into a one element at a time through operator(), after asserting
// that both have the same dimensions.  Presumably this is what keeps the
// assignment off the BLAS-bound matrix assignment path, so the result can
// serve as a reference value -- confirm against the matrix assignment code.
//
// The destination is taken by const reference and the constness is cast away
// because callers in this file pass the result of set_subm()/set_colm()/
// set_rowm() directly as the first argument.
template <typename T, typename U>
void assign_no_blas (
    const T& a_,
    const U& b
)
{
    T& a = const_cast<T&>(a_);

    DLIB_CASSERT(a.nr() == b.nr(),"");
    DLIB_CASSERT(a.nc() == b.nc(),"");

    for (long row = 0; row < a.nr(); ++row)
    {
        for (long col = 0; col < a.nc(); ++col)
            a(row,col) = b(row,col);
    }
}
// Runs six randomized trials comparing expressions evaluated through
// assignment (tmp(...), =, +=, -=, set_subm/set_colm/set_rowm) against the
// same expressions evaluated element by element.
//
// rows, cols - dimensions of the random matrix a (and of its complex
//              counterpart c_a) that the trials are built from.
//              NOTE(review): the fixed indices in rowm(a,1), colm(at,1),
//              subm(a,1,1,2,2) and subm(a,1,2,2,2) presumably require
//              rows >= 3 and cols >= 4 -- confirm.
//
// Failures are reported through DLIB_CASSERT (directly or via check_equal /
// c_check_equal).
template <typename type>
void test_blas( long rows, long cols)
{
    // The tests in this function exercise the BLAS bindings located in the matrix/matrix_blas_bindings.h file.
    // It does this by performing an assignment that is subject to BLAS bindings and comparing the
    // results directly to an unevaluated matrix_exp that should be equal.

    dlib::rand::float_1a rnd;
    matrix<type> a(rows,cols), temp, temp2;

    // The trial counter is deliberately not called i: the body declares the
    // imaginary unit under that name.
    for (int iter = 0; iter < 6; ++iter)
    {
        for (long r= 0; r < a.nr(); ++r)
        {
            for (long c = 0; c < a.nc(); ++c)
            {
                a(r,c) = static_cast<type>(10*rnd.get_random_double());
            }
        }
        matrix<type> at;
        at = trans(a);

        matrix<complex<type> > c_a(rows,cols), c_at;
        for (long r= 0; r < a.nr(); ++r)
        {
            for (long c = 0; c < a.nc(); ++c)
            {
                c_a(r,c) = complex<type>(10*rnd.get_random_double(),10*rnd.get_random_double());
            }
        }
        c_at = trans(c_a);

        matrix<complex<type> > c_temp(cols,cols), c_temp2(cols,cols);

        // Imaginary unit and the scalar constants used to build the complex
        // fill values for the GER tests below.
        const complex<type> i(0,1);
        const type one = 1;
        const type two = 2;  // was initialised to 1, contradicting its name
        const type num1 = 3.6;
        const type num2 = 6.6;
        const type num3 = 8.6;

        matrix<complex<type>,0,1> c_cv4(cols), c_cv3(rows);
        matrix<complex<type>,1,0> c_rv4(cols), c_rv3(rows);

        // Random column/row vectors sized to match the columns (..4) and
        // rows (..3) of a.  The fill loops use idx so they do not shadow the
        // imaginary unit i.
        matrix<type,0,1> cv4(cols);
        for (long idx = 0; idx < cv4.size(); ++idx)
            cv4(idx) = 10*rnd.get_random_double();
        for (long idx = 0; idx < c_cv4.size(); ++idx)
            c_cv4(idx) = complex<type>(10*rnd.get_random_double(),10*rnd.get_random_double());

        matrix<type,1,0> rv3(rows);
        for (long idx = 0; idx < rv3.size(); ++idx)
            rv3(idx) = 10*rnd.get_random_double();
        for (long idx = 0; idx < c_rv3.size(); ++idx)
            c_rv3(idx) = complex<type>(10*rnd.get_random_double(),10*rnd.get_random_double());

        matrix<type,0,1> cv3(rows);
        for (long idx = 0; idx < cv3.size(); ++idx)
            cv3(idx) = 10*rnd.get_random_double();
        for (long idx = 0; idx < c_cv3.size(); ++idx)
            c_cv3(idx) = complex<type>(10*rnd.get_random_double(),10*rnd.get_random_double());

        matrix<type,1,0> rv4(cols);
        for (long idx = 0; idx < rv4.size(); ++idx)
            rv4(idx) = 10*rnd.get_random_double();
        for (long idx = 0; idx < c_rv4.size(); ++idx)
            c_rv4(idx) = complex<type>(10*rnd.get_random_double(),10*rnd.get_random_double());


        // GEMM tests
        dlog << LTRACE << "1.1";
        check_equal(tmp(at*a), at*a);
        dlog << LTRACE << "1.2";
        check_equal(tmp(trans(a)*a), trans(a)*a);
        dlog << LTRACE << "1.3";
        check_equal(tmp(at*trans(at)), at*trans(at));
        dlog << LTRACE << "1.4";
        check_equal(tmp(trans(at)*trans(a)), a*at);
        dlog << LTRACE << "1.5";

        print_spinner();
        c_check_equal(tmp(conj(trans(c_a))*c_a), trans(conj(c_a))*c_a);
        dlog << LTRACE << "1.6";
        c_check_equal(tmp(c_at*trans(conj(c_at))), c_at*conj(trans(c_at)));
        dlog << LTRACE << "1.7";
        c_check_equal(tmp(conj(trans(c_at))*trans(conj(c_a))), conj(trans(c_at))*trans(conj(c_a)));
        dlog << LTRACE << "1.8";

        check_equal(tmp(a*trans(rowm(a,1))) , a*trans(rowm(a,1)));
        check_equal(tmp(a*colm(at,1)) , a*colm(at,1));
        check_equal(tmp(subm(a,1,1,2,2)*subm(a,1,2,2,2)), subm(a,1,1,2,2)*subm(a,1,2,2,2));

        // Scaled products accumulated into / subtracted from an existing
        // matrix; temp2 tracks the same updates via assign_no_blas.
        temp = at*a;
        temp2 = temp;

        temp += 3.5*at*a;
        assign_no_blas(temp2, temp2 + 3.5*at*a);
        check_equal(temp, temp2);

        temp -= at*3.5*a;
        assign_no_blas(temp2, temp2 - at*3.5*a);
        check_equal(temp, temp2);

        temp = temp + 4*at*a;
        assign_no_blas(temp2, temp2 + 4*at*a);
        check_equal(temp, temp2);

        temp = temp - 2.4*at*a;
        assign_no_blas(temp2, temp2 - 2.4*at*a);
        check_equal(temp, temp2);

        // GEMV tests
        check_equal(tmp(a*cv4), a*cv4);
        check_equal(tmp(rv3*a), rv3*a);
        check_equal(tmp(trans(cv4)*at), trans(cv4)*at);
        check_equal(tmp(a*trans(rv4)), a*trans(rv4));
        check_equal(tmp(trans(a)*cv3), trans(a)*cv3);
        check_equal(tmp(rv4*trans(a)), rv4*trans(a));
        check_equal(tmp(trans(cv3)*trans(at)), trans(cv3)*trans(at));
        check_equal(tmp(trans(cv3)*a), trans(cv3)*a);
        check_equal(tmp(trans(a)*trans(rv3)), trans(a)*trans(rv3));

        c_check_equal(tmp(trans(conj(c_a))*c_cv3), trans(conj(c_a))*c_cv3);
        c_check_equal(tmp(c_rv4*trans(conj(c_a))), c_rv4*trans(conj(c_a)));
        c_check_equal(tmp(trans(c_cv3)*trans(conj(c_at))), trans(c_cv3)*trans(conj(c_at)));
        c_check_equal(tmp(conj(trans(c_a))*trans(c_rv3)), trans(conj(c_a))*trans(c_rv3));
        c_check_equal(tmp(c_rv4*conj(c_at)), c_rv4*conj(c_at));
        c_check_equal(tmp(trans(c_cv4)*conj(c_at)), trans(c_cv4)*conj(c_at));

        // Expressions in which the destination also appears on the right
        // hand side of the assignment.
        dlog << LTRACE << "6";
        temp = a*at;
        check_equal(temp, a*at);
        temp = temp + a*at + trans(at)*at + trans(at)*sin(at);
        check_equal(temp, a*at + a*at+ trans(at)*at + trans(at)*sin(at));

        dlog << LTRACE << "6.1";
        temp = a*at;
        check_equal(temp, a*at);
        temp = a*at + temp;
        check_equal(temp, a*at + a*at);

        print_spinner();
        dlog << LTRACE << "6.2";
        temp = a*at;
        check_equal(temp, a*at);
        dlog << LTRACE << "6.2.3";
        temp = temp - a*at;
        dlog << LTRACE << "6.2.4";
        check_equal(temp, a*at-a*at);

        dlog << LTRACE << "6.3";
        temp = a*at;
        dlog << LTRACE << "6.3.5";
        check_equal(temp, a*at);
        dlog << LTRACE << "6.3.6";
        temp = a*at - temp;
        dlog << LTRACE << "6.4";
        check_equal(temp, a*at-a*at);

        // Assignment of products into sub-matrix, column and row views.
        const long d = min(rows,cols);
        rectangle rect(1,1,d,d);
        temp.set_size(max(rows,cols)+4,max(rows,cols)+4);
        set_all_elements(temp,4);
        temp2 = temp;

        dlog << LTRACE << "7";
        set_subm(temp,rect) = a*at;
        assign_no_blas( set_subm(temp2,rect) , a*at);
        check_equal(temp, temp2);

        temp = a;
        temp2 = a;
        set_colm(temp,1) = a*cv4;
        assign_no_blas( set_colm(temp2,1) , a*cv4);
        check_equal(temp, temp2);

        set_rowm(temp,1) = rv3*a;
        assign_no_blas( set_rowm(temp2,1) , rv3*a);
        check_equal(temp, temp2);

        // Test BLAS GER
        temp.set_size(cols,cols);
        set_all_elements(temp,3);
        temp2 = temp;

        dlog << LTRACE << "8";
        temp += cv4*rv4;
        assign_no_blas(temp2, temp2 + cv4*rv4);
        check_equal(temp, temp2);

        dlog << LTRACE << "8.3";
        temp = temp + cv4*rv4;
        assign_no_blas(temp2, temp2 + cv4*rv4);
        check_equal(temp, temp2);
        dlog << LTRACE << "8.9";

        set_all_elements(c_temp, one + num1*i);
        c_temp2 = c_temp;
        set_all_elements(c_rv4, one + num2*i);
        set_all_elements(c_cv4, two + num3*i);

        dlog << LTRACE << "9";
        c_temp += c_cv4*c_rv4;
        assign_no_blas(c_temp2, c_temp2 + c_cv4*c_rv4);
        c_check_equal(c_temp, c_temp2);
        dlog << LTRACE << "9.1";
        c_temp += c_cv4*conj(c_rv4);
        assign_no_blas(c_temp2, c_temp2 + c_cv4*conj(c_rv4));
        c_check_equal(c_temp, c_temp2);
        dlog << LTRACE << "9.2";
        c_temp = c_cv4*conj(c_rv4) + c_temp;
        assign_no_blas(c_temp2, c_temp2 + c_cv4*conj(c_rv4));
        c_check_equal(c_temp, c_temp2);
        dlog << LTRACE << "9.3";
        c_temp = trans(c_rv4)*trans(conj(c_cv4)) + c_temp;
        assign_no_blas(c_temp2, c_temp2 + trans(c_rv4)*trans(conj(c_cv4)));
        c_check_equal(c_temp, c_temp2);

        dlog << LTRACE << "10";
        print_spinner();

        // Test DOT
        check_equal( tmp(rv4*cv4), rv4*cv4);
        check_equal( tmp(trans(cv4)*trans(rv4)), trans(cv4)*trans(rv4));
        check_equal( tmp(rv4*3.9*cv4), rv4*3.9*cv4);
        check_equal( tmp(trans(cv4)*3.9*trans(rv4)), trans(cv4)*3.9*trans(rv4));
        check_equal( tmp(rv4*cv4*3.9), rv4*3.9*cv4);
        check_equal( tmp(trans(cv4)*trans(rv4)*3.9), trans(cv4)*3.9*trans(rv4));

        temp.set_size(1,1);
        temp = 4;
        check_equal( tmp(temp + rv4*cv4), temp + rv4*cv4);
        check_equal( tmp(temp + trans(cv4)*trans(rv4)), temp + trans(cv4)*trans(rv4));

        dlog << LTRACE << "11";

        c_check_equal( tmp(conj(c_rv4)*c_cv4), conj(c_rv4)*c_cv4);
        c_check_equal( tmp(conj(trans(c_cv4))*trans(c_rv4)), trans(conj(c_cv4))*trans(c_rv4));
        c_check_equal( tmp(conj(c_rv4)*i*c_cv4), conj(c_rv4)*i*c_cv4);
        c_check_equal( tmp(conj(trans(c_cv4))*i*trans(c_rv4)), trans(conj(c_cv4))*i*trans(c_rv4));

        c_temp.set_size(1,1);
        c_temp = 4;
        c_check_equal( tmp(c_temp + conj(c_rv4)*c_cv4), c_temp + conj(c_rv4)*c_cv4);
        c_check_equal( tmp(c_temp + trans(conj(c_cv4))*trans(c_rv4)), c_temp + trans(conj(c_cv4))*trans(c_rv4));

        // A 1x1 product used as a scalar must agree with its single element.
        DLIB_CASSERT(abs((static_cast<complex<type> >(c_rv4*c_cv4) + i) - ((c_rv4*c_cv4)(0) + i)) < std::numeric_limits<type>::epsilon()*eps_mul ,"");
        DLIB_CASSERT(abs((rv4*cv4 + 1.0) - ((rv4*cv4)(0) + 1.0)) < std::numeric_limits<type>::epsilon()*eps_mul,"");
    }
}
void matrix_test ( void matrix_test (
) )
/*! /*!
...@@ -52,6 +394,24 @@ namespace ...@@ -52,6 +394,24 @@ namespace
DLIB_CASSERT(subm(tensor_product(m1,m2),range(2,3), range(2,3)) == 4*m2,""); DLIB_CASSERT(subm(tensor_product(m1,m2),range(2,3), range(2,3)) == 4*m2,"");
} }
{
print_spinner();
dlog << LTRACE << "testing blas stuff";
dlog << LTRACE << " \nsmall double";
test_blas<double>(3,4);
print_spinner();
dlog << LTRACE << " \nsmall float";
test_blas<float>(3,4);
print_spinner();
dlog << LTRACE << " \nbig double";
test_blas<double>(120,131);
print_spinner();
dlog << LTRACE << " \nbig float";
test_blas<float>(120,131);
print_spinner();
dlog << LTRACE << "testing done";
}
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment