Commit a5e6f0a7 authored by Davis King's avatar Davis King

Fixed a bug in the matrix BLAS bindings that caused BLAS to return an invalid

argument error.  The error occurred when general matrix multiply expressions
were transposed and didn't result in a square matrix.  E.g. mat = trans(a*b)
where mat isn't square.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403368
parent 45e99cee
......@@ -488,7 +488,7 @@ namespace dlib
if (transpose == false)
cblas_gemm(Order, CblasNoTrans, CblasNoTrans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
cblas_gemm(Order, CblasTrans, CblasTrans, M, N, K, alpha, B, ldb, A, lda, beta, C, ldc);
cblas_gemm(Order, CblasTrans, CblasTrans, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
} DLIB_END_BLAS_BINDING
......@@ -516,7 +516,7 @@ namespace dlib
if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, B, ldb, A, lda, beta, C, ldc);
cblas_gemm(Order, TransA, TransB, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
} DLIB_END_BLAS_BINDING
......@@ -544,7 +544,7 @@ namespace dlib
if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, B, ldb, A, lda, beta, C, ldc);
cblas_gemm(Order, TransA, TransB, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
} DLIB_END_BLAS_BINDING
......@@ -570,7 +570,7 @@ namespace dlib
if (transpose == false)
cblas_gemm(Order, CblasTrans, CblasTrans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
cblas_gemm(Order, CblasNoTrans, CblasNoTrans, M, N, K, alpha, B, ldb, A, lda, beta, C, ldc);
cblas_gemm(Order, CblasNoTrans, CblasNoTrans, N, M, K, alpha, B, ldb, A, lda, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......
......@@ -272,6 +272,43 @@ namespace
check_equal(temp, temp3);
}
dlog << LTRACE << "1.17.1";
{
matrix<type> m1, m2;
m1 = matrix_cast<type>(randm(rows, cols, rnd));
m2 = matrix_cast<type>(randm(cols, rows + 8, rnd));
check_equal(tmp(m1*m2), m1*m2);
check_equal(tmp(trans(m1*m2)), trans(m1*m2));
m1 = trans(m1);
check_equal(tmp(trans(m1)*m2), trans(m1)*m2);
check_equal(tmp(trans(trans(m1)*m2)), trans(trans(m1)*m2));
m2 = trans(m2);
check_equal(tmp(trans(m1)*trans(m2)), trans(m1)*trans(m2));
check_equal(tmp(trans(trans(m1)*trans(m2))), trans(trans(m1)*trans(m2)));
m1 = trans(m1);
check_equal(tmp(m1*trans(m2)), m1*trans(m2));
check_equal(tmp(trans(m1*trans(m2))), trans(m1*trans(m2)));
}
dlog << LTRACE << "1.17.5";
{
matrix<type,1,0> r;
matrix<type,0,1> c;
r = matrix_cast<type>(randm(1, rows+9, rnd));
c = matrix_cast<type>(randm(rows, 1, rnd));
check_equal(tmp(c*r), c*r);
check_equal(tmp(trans(c*r)), trans(c*r));
check_equal(tmp(trans(r)*trans(c)), trans(r)*trans(c));
check_equal(tmp(trans(trans(r)*trans(c))), trans(trans(r)*trans(c)));
}
dlog << LTRACE << "1.18";
// GEMV tests
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment