Commit 4554263b authored by Davis King's avatar Davis King

Cleaned up a few things and also fixed the compiler error in the

tensor_product() function.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402756
parent a42e7fa8
......@@ -1297,7 +1297,7 @@ namespace dlib
COMPILE_TIME_ASSERT((is_same_type<typename EXP::type,type>::value == true));
if (m.destructively_aliases(*this) == false)
{
matrix_assign(*this, m + *this);
matrix_assign(*this, *this + m);
}
else
{
......@@ -1305,7 +1305,7 @@ namespace dlib
// this->data is aliased inside the matrix_exp m somewhere.
matrix temp;
temp.set_size(m.nr(),m.nc());
matrix_assign(temp, m + *this);
matrix_assign(temp, *this + m);
temp.swap(*this);
}
return *this;
......
......@@ -26,24 +26,24 @@ namespace dlib
// ------------------------------------------------------------------------------------
// This template struct is used to tell us if two matrix types both contain the same type of
// element, have the same layout, and are both general matrices or both vectors.
// Primary template: any two types that don't match the matrix/matrix
// specialization below are considered different, so value defaults to false.
template <typename T, typename U>
struct same_matrix
{
const static bool value = false;
};
template <typename T, typename L, long a, long b, long c, long d, typename MM1, typename MM2 >
struct same_matrix <matrix<T,a,b,MM1,L>, matrix<T,c,d,MM2,L> >
template <typename T1, typename T2, typename L1, typename L2, long NR1, long NC1, long NR2, long NC2, typename MM1, typename MM2 >
struct same_matrix <matrix<T1,NR1,NC1,MM1,L1>, matrix<T2,NR2,NC2,MM2,L2> >
{
/*! These two matrices are the same if they are either:
- both row vectors
- both column vectors
- both not any kind of vector
- both general matrices with the same kind of layout type
!*/
const static bool value = (a == 1 && c == 1) || (b==1 && d==1) || (a!=1 && b!=1 && c!=1 && d!=1) ;
const static bool value = (NR1 == 1 && NR2 == 1) ||
(NC1==1 && NC2==1) ||
(NR1!=1 && NC1!=1 && NR2!=1 && NC2!=1 && is_same_type<L1,L2>::value);
};
// ------------------------------------------------------------------------------------
......@@ -149,6 +149,7 @@ namespace dlib
const src_exp& src \
) {
#define DLIB_END_BLAS_BINDING }};
// ------------------------------------------------------------------------------------
......
......@@ -18,21 +18,28 @@ namespace dlib
// ----------------------------------------------------------------------------------------
typedef memory_manager<char>::kernel_1a mm;
extern matrix<double,0,0,mm,row_major_layout> dm;
extern matrix<float,0,0,mm,row_major_layout> sm;
// Here we declare some matrix objects for use in the DLIB_ADD_BLAS_BINDING macro. These
// extern declarations don't actually correspond to any real matrix objects. They are
// simply here so we can build matrix expressions with the DLIB_ADD_BLAS_BINDING macro.
extern matrix<double,1,0,mm,row_major_layout> drv;
extern matrix<double,0,1,mm,row_major_layout> dcv;
typedef memory_manager<char>::kernel_1a mm;
// Note that the fact that these are double matrices isn't important. The type
// that matters is the one that is the first argument of the DLIB_ADD_BLAS_BINDING.
// That type determines what the type of the elements of the matrices that we
// are dealing with is.
extern matrix<double,0,0,mm,row_major_layout> rm; // general matrix with row major order
extern matrix<double,0,0,mm,column_major_layout> cm; // general matrix with column major order
extern matrix<double,1,0> rv; // general row vector
extern matrix<double,0,1> cv; // general column vector
extern matrix<float,1,0,mm,row_major_layout> srv;
extern matrix<float,0,1,mm,row_major_layout> scv;
using namespace std;
#ifdef DLIB_FOUND_BLAS
DLIB_ADD_BLAS_BINDING(double, row_major_layout, dm*dm)
DLIB_ADD_BLAS_BINDING(double, row_major_layout, rm*rm)
{
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
......@@ -50,10 +57,11 @@ namespace dlib
const int ldc = src.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}};
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(double, row_major_layout, trans(dm)*dm)
DLIB_ADD_BLAS_BINDING(double, row_major_layout, trans(rm)*rm)
{
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
......@@ -71,12 +79,13 @@ namespace dlib
const int ldc = src.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}};
} DLIB_END_BLAS_BINDING
// double overloads
// -------------------------- float overloads --------------------------
DLIB_ADD_BLAS_BINDING(float, row_major_layout, sm*sm)
DLIB_ADD_BLAS_BINDING(float, row_major_layout, rm*rm)
{
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
......@@ -94,10 +103,11 @@ namespace dlib
const int ldc = src.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}};
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(float, row_major_layout, trans(sm)*sm)
DLIB_ADD_BLAS_BINDING(float, row_major_layout, trans(rm)*rm)
{
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
......@@ -115,7 +125,7 @@ namespace dlib
const int ldc = src.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}};
} DLIB_END_BLAS_BINDING
#endif // DLIB_FOUND_BLAS
......
......@@ -3203,14 +3203,14 @@ convergence:
typename EXP1,
typename EXP2
>
inline const matrix_exp<matrix_binary_exp<EXP1,EXP2,op_tensor_product> > tensor_product (
inline const matrix_binary_exp<EXP1,EXP2,op_tensor_product> tensor_product (
const matrix_exp<EXP1>& a,
const matrix_exp<EXP2>& b
)
{
COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
typedef matrix_binary_exp<EXP1,EXP2,op_tensor_product> exp;
return matrix_exp<exp>(exp(a.ref(),b.ref()));
return exp(a.ref(),b.ref());
}
// ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment