Commit c596cebb authored by Davis King's avatar Davis King

Improved the BLAS binding system. It should now break expressions up into

their proper basic BLAS function calls.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402764
parent 6b5f9a8a
This diff is collapsed.
...@@ -44,6 +44,12 @@ namespace dlib ...@@ -44,6 +44,12 @@ namespace dlib
EXP1& dest, EXP1& dest,
const EXP2& src const EXP2& src
) )
/*!
requires
- src.destructively_aliases(dest) == false
ensures
- #dest == src
!*/
{ {
for (long r = 0; r < src.nr(); ++r) for (long r = 0; r < src.nr(); ++r)
{ {
...@@ -54,6 +60,83 @@ namespace dlib ...@@ -54,6 +60,83 @@ namespace dlib
} }
} }
// ----------------------------------------------------------------------------------------
template <typename EXP1, typename EXP2>
inline static void matrix_assign_default (
EXP1& dest,
const EXP2& src,
typename EXP2::type alpha,
bool add_to
)
/*!
requires
- src.destructively_aliases(dest) == false
ensures
- if (add_to == false) then
- #dest == alpha*src
- else
- #dest == dest + alpha*src
!*/
{
if (add_to)
{
if (alpha == 1)
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) += src(r,c);
}
}
}
else if (alpha == -1)
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) -= src(r,c);
}
}
}
else
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) += alpha*src(r,c);
}
}
}
}
else
{
if (alpha == 1)
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) = src(r,c);
}
}
}
else
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) = alpha*src(r,c);
}
}
}
}
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
...@@ -83,7 +166,6 @@ namespace dlib ...@@ -83,7 +166,6 @@ namespace dlib
- src.destructively_aliases(dest) == false - src.destructively_aliases(dest) == false
ensures ensures
- #dest == src - #dest == src
- the part of dest outside the above sub matrix remains unchanged
!*/ !*/
{ {
// Call src.ref() here so that the derived type of the matrix_exp shows // Call src.ref() here so that the derived type of the matrix_exp shows
...@@ -107,7 +189,6 @@ namespace dlib ...@@ -107,7 +189,6 @@ namespace dlib
- src.destructively_aliases(dest) == false - src.destructively_aliases(dest) == false
ensures ensures
- #dest == src - #dest == src
- the part of dest outside the above sub matrix remains unchanged
!*/ !*/
{ {
matrix_assign_default(dest,src.ref()); matrix_assign_default(dest,src.ref());
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include "cblas.h" #include "cblas.h"
#endif #endif
#include <iostream>
namespace dlib namespace dlib
{ {
...@@ -219,45 +221,18 @@ namespace dlib ...@@ -219,45 +221,18 @@ namespace dlib
const int M = static_cast<int>(src.nr()); const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc()); const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc()); const int K = static_cast<int>(src.lhs.nc());
const T alpha = 1;
const T* A = &src.lhs(0,0); const T* A = &src.lhs(0,0);
const int lda = src.lhs.nc(); const int lda = src.lhs.nc();
const T* B = &src.rhs(0,0); const T* B = &src.rhs(0,0);
const int ldb = src.rhs.nc(); const int ldb = src.rhs.nc();
const T beta = 0; const T beta = add_to?1:0;
T* C = &dest(0,0); T* C = &dest(0,0);
const int ldc = src.nc(); const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout,rm + rm*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const T alpha = 1;
const T* A = &src.rhs.lhs(0,0);
const int lda = src.rhs.lhs.nc();
const T* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const T beta = 1;
T* C = &dest(0,0);
const int ldc = src.rhs.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*rm) DLIB_ADD_BLAS_BINDING(row_major_layout, trans(rm)*rm)
...@@ -268,94 +243,18 @@ namespace dlib ...@@ -268,94 +243,18 @@ namespace dlib
const int M = static_cast<int>(src.nr()); const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc()); const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc()); const int K = static_cast<int>(src.lhs.nc());
const T alpha = 1;
const T* A = &src.lhs.m(0,0); const T* A = &src.lhs.m(0,0);
const int lda = src.lhs.m.nc(); const int lda = src.lhs.m.nc();
const T* B = &src.rhs(0,0); const T* B = &src.rhs(0,0);
const int ldb = src.rhs.nc(); const int ldb = src.rhs.nc();
const T beta = 0; const T beta = add_to?1:0;
T* C = &dest(0,0); T* C = &dest(0,0);
const int ldc = src.nc(); const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout, rm + s*trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc());
const int K = static_cast<int>(src.rhs.m.lhs.nc());
const T alpha = src.rhs.s;
const T* A = &src.rhs.m.lhs.m(0,0);
const int lda = src.rhs.m.lhs.m.nc();
const T* B = &src.rhs.m.rhs(0,0);
const int ldb = src.rhs.m.rhs.nc();
const T beta = 1;
T* C = &dest(0,0);
const int ldc = dest.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout, s*trans(rm)*rm)
{
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.m.nr());
const int N = static_cast<int>(src.m.nc());
const int K = static_cast<int>(src.m.lhs.nc());
const T alpha = src.s;
const T* A = &src.m.lhs.m(0,0);
const int lda = src.m.lhs.m.nc();
const T* B = &src.m.rhs(0,0);
const int ldb = src.m.rhs.nc();
const T beta = 0;
T* C = &dest(0,0);
const int ldc = dest.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(row_major_layout, rm + trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const T alpha = 1;
const T* A = &src.rhs.lhs.m(0,0);
const int lda = src.rhs.lhs.m.nc();
const T* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const T beta = 1;
T* C = &dest(0,0);
const int ldc = src.nc();
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
#endif // DLIB_USE_BLAS #endif // DLIB_USE_BLAS
} }
......
...@@ -23,24 +23,25 @@ namespace dlib ...@@ -23,24 +23,25 @@ namespace dlib
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template < /*! This file defines the default_matrix_multiply() function. It is a function
typename matrix_dest_type, that conforms to the following definition:
typename EXP1,
typename EXP2 template <
> typename matrix_dest_type,
typename enable_if_c<ma::matrix_is_vector<EXP1>::value == true && ma::matrix_is_vector<EXP2>::value == true>::type typename EXP1,
default_matrix_multiply ( typename EXP2
matrix_dest_type& dest, >
const EXP1& lhs, void default_matrix_multiply (
const EXP2& rhs matrix_dest_type& dest,
); const EXP1& lhs,
/*! const EXP2& rhs
requires );
- (lhs*rhs).destructively_aliases(dest) == false requires
- dest.nr() == (lhs*rhs).nr() - (lhs*rhs).destructively_aliases(dest) == false
- dest.nc() == (lhs*rhs).nc() - dest.nr() == (lhs*rhs).nr()
ensures - dest.nc() == (lhs*rhs).nc()
- #dest == dest + lhs*rhs ensures
- #dest == dest + lhs*rhs
!*/ !*/
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment