Commit 4554263b authored by Davis King's avatar Davis King

Cleaned up a few things and also fixed the compiler error in the

tensor_product() function.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402756
parent a42e7fa8
...@@ -1297,7 +1297,7 @@ namespace dlib ...@@ -1297,7 +1297,7 @@ namespace dlib
COMPILE_TIME_ASSERT((is_same_type<typename EXP::type,type>::value == true)); COMPILE_TIME_ASSERT((is_same_type<typename EXP::type,type>::value == true));
if (m.destructively_aliases(*this) == false) if (m.destructively_aliases(*this) == false)
{ {
matrix_assign(*this, m + *this); matrix_assign(*this, *this + m);
} }
else else
{ {
...@@ -1305,7 +1305,7 @@ namespace dlib ...@@ -1305,7 +1305,7 @@ namespace dlib
// this->data is aliased inside the matrix_exp m somewhere. // this->data is aliased inside the matrix_exp m somewhere.
matrix temp; matrix temp;
temp.set_size(m.nr(),m.nc()); temp.set_size(m.nr(),m.nc());
matrix_assign(temp, m + *this); matrix_assign(temp, *this + m);
temp.swap(*this); temp.swap(*this);
} }
return *this; return *this;
......
...@@ -26,24 +26,24 @@ namespace dlib ...@@ -26,24 +26,24 @@ namespace dlib
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
// This template struct is used to tell is if two matrix types both contain the same type of
// element, have the same layout, and are both general matrices or both vectors.
template <typename T, typename U> template <typename T, typename U>
struct same_matrix struct same_matrix
{ {
const static bool value = false; const static bool value = false;
}; };
template <typename T, typename L, long a, long b, long c, long d, typename MM1, typename MM2 > template <typename T1, typename T2, typename L1, typename L2, long NR1, long NC1, long NR2, long NC2, typename MM1, typename MM2 >
struct same_matrix <matrix<T,a,b,MM1,L>, matrix<T,c,d,MM2,L> > struct same_matrix <matrix<T1,NR1,NC1,MM1,L1>, matrix<T2,NR2,NC2,MM2,L2> >
{ {
/*! These two matrices are the same if they are either: /*! These two matrices are the same if they are either:
- both row vectors - both row vectors
- both column vectors - both column vectors
- both not any kind of vector - both general matrices with the same kind of layout type
!*/ !*/
const static bool value = (a == 1 && c == 1) || (b==1 && d==1) || (a!=1 && b!=1 && c!=1 && d!=1) ; const static bool value = (NR1 == 1 && NR2 == 1) ||
(NC1==1 && NC2==1) ||
(NR1!=1 && NC1!=1 && NR2!=1 && NC2!=1 && is_same_type<L1,L2>::value);
}; };
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
...@@ -149,6 +149,7 @@ namespace dlib ...@@ -149,6 +149,7 @@ namespace dlib
const src_exp& src \ const src_exp& src \
) { ) {
#define DLIB_END_BLAS_BINDING }};
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
......
...@@ -18,21 +18,28 @@ namespace dlib ...@@ -18,21 +18,28 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
typedef memory_manager<char>::kernel_1a mm; // Here we declare some matrix objects for use in the DLIB_ADD_BLAS_BINDING macro. These
extern matrix<double,0,0,mm,row_major_layout> dm; // extern declarations don't actually correspond to any real matrix objects. They are
extern matrix<float,0,0,mm,row_major_layout> sm; // simply here so we can build matrix expressions with the DLIB_ADD_BLAS_BINDING marco.
extern matrix<double,1,0,mm,row_major_layout> drv; typedef memory_manager<char>::kernel_1a mm;
extern matrix<double,0,1,mm,row_major_layout> dcv; // Note that the fact that these are double matrices isn't important. The type
// that matters is the one that is the first argument of the DLIB_ADD_BLAS_BINDING.
// That type determines what the type of the elements of the matrices that we
// are dealing with is.
extern matrix<double,0,0,mm,row_major_layout> rm; // general matrix with row major order
extern matrix<double,0,0,mm,column_major_layout> cm; // general matrix with column major order
extern matrix<double,1,0> rv; // general row vector
extern matrix<double,0,1> cv; // general column vector
extern matrix<float,1,0,mm,row_major_layout> srv;
extern matrix<float,0,1,mm,row_major_layout> scv;
using namespace std; using namespace std;
#ifdef DLIB_FOUND_BLAS #ifdef DLIB_FOUND_BLAS
DLIB_ADD_BLAS_BINDING(double, row_major_layout, dm*dm) DLIB_ADD_BLAS_BINDING(double, row_major_layout, rm*rm)
{
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans; const CBLAS_TRANSPOSE TransB = CblasNoTrans;
...@@ -50,10 +57,11 @@ namespace dlib ...@@ -50,10 +57,11 @@ namespace dlib
const int ldc = src.nc(); const int ldc = src.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}}; } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(double, row_major_layout, trans(dm)*dm) DLIB_ADD_BLAS_BINDING(double, row_major_layout, trans(rm)*rm)
{
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans; const CBLAS_TRANSPOSE TransB = CblasNoTrans;
...@@ -71,12 +79,13 @@ namespace dlib ...@@ -71,12 +79,13 @@ namespace dlib
const int ldc = src.nc(); const int ldc = src.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}}; } DLIB_END_BLAS_BINDING
// double overloads // -------------------------- float overloads --------------------------
DLIB_ADD_BLAS_BINDING(float, row_major_layout, sm*sm) DLIB_ADD_BLAS_BINDING(float, row_major_layout, rm*rm)
{
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans; const CBLAS_TRANSPOSE TransB = CblasNoTrans;
...@@ -94,10 +103,11 @@ namespace dlib ...@@ -94,10 +103,11 @@ namespace dlib
const int ldc = src.nc(); const int ldc = src.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}}; } DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(float, row_major_layout, trans(sm)*sm) DLIB_ADD_BLAS_BINDING(float, row_major_layout, trans(rm)*rm)
{
const CBLAS_ORDER Order = CblasRowMajor; const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans; const CBLAS_TRANSPOSE TransB = CblasNoTrans;
...@@ -115,7 +125,7 @@ namespace dlib ...@@ -115,7 +125,7 @@ namespace dlib
const int ldc = src.nc(); const int ldc = src.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}}; } DLIB_END_BLAS_BINDING
#endif // DLIB_FOUND_BLAS #endif // DLIB_FOUND_BLAS
......
...@@ -3203,14 +3203,14 @@ convergence: ...@@ -3203,14 +3203,14 @@ convergence:
typename EXP1, typename EXP1,
typename EXP2 typename EXP2
> >
inline const matrix_exp<matrix_binary_exp<EXP1,EXP2,op_tensor_product> > tensor_product ( inline const matrix_binary_exp<EXP1,EXP2,op_tensor_product> tensor_product (
const matrix_exp<EXP1>& a, const matrix_exp<EXP1>& a,
const matrix_exp<EXP2>& b const matrix_exp<EXP2>& b
) )
{ {
COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true)); COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
typedef matrix_binary_exp<EXP1,EXP2,op_tensor_product> exp; typedef matrix_binary_exp<EXP1,EXP2,op_tensor_product> exp;
return matrix_exp<exp>(exp(a.ref(),b.ref())); return exp(a.ref(),b.ref());
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment