Commit a87e2b8e authored by Davis King's avatar Davis King

Added stuff to make matrix sub-expressions get bound into BLAS

calls.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402768
parent d1c92fb8
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "../geometry.h" #include "../geometry.h"
#include "matrix.h" #include "matrix.h"
#include "matrix_utilities.h" #include "matrix_utilities.h"
#include "matrix_subexp.h"
#include "../enable_if.h" #include "../enable_if.h"
#include "matrix_assign_fwd.h" #include "matrix_assign_fwd.h"
#include "matrix_default_mul.h" #include "matrix_default_mul.h"
...@@ -57,28 +58,77 @@ namespace dlib ...@@ -57,28 +58,77 @@ namespace dlib
struct has_matrix_multiply<matrix_unary_exp<T,OP> > struct has_matrix_multiply<matrix_unary_exp<T,OP> >
{ const static bool value = has_matrix_multiply<T>::value; }; { const static bool value = has_matrix_multiply<T>::value; };
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template <typename T, typename U> const int unknown_matrix = 0;
struct same_matrix const int general_matrix = 1;
const int row_matrix = 2;
const int column_matrix = 3;
// ------------------------------------------------------------------------------------
template <typename T>
struct matrix_type_id
{ {
const static bool value = false; const static int value = unknown_matrix;
}; };
template <typename T1, typename T2, typename L1, typename L2, long NR1, long NC1, long NR2, long NC2, typename MM1, typename MM2 > template <typename T, long NR, long NC, typename MM, typename L>
struct same_matrix <matrix<T1,NR1,NC1,MM1,L1>, matrix<T2,NR2,NC2,MM2,L2> > struct matrix_type_id<matrix<T,NR,NC,MM,L> >
{ {
/*! These two matrices are the same if they are either: const static int value = general_matrix;
- both row vectors
- both column vectors
- both general non-vector matrices
!*/
const static bool value = (NR1 == 1 && NR2 == 1) ||
(NC1==1 && NC2==1) ||
(NR1!=1 && NC1!=1 && NR2!=1 && NC2!=1);
}; };
template <typename T, long NR, typename MM, typename L>
struct matrix_type_id<matrix<T,NR,1,MM,L> >
{
const static int value = column_matrix;
};
template <typename T, long NC, typename MM, typename L>
struct matrix_type_id<matrix<T,1,NC,MM,L> >
{
const static int value = row_matrix;
};
// ------------------------------------------------------------------------------------
template <typename T, long NR, long NC, typename MM, typename L>
struct matrix_type_id<matrix_scalar_binary_exp<matrix<T,NR,NC,MM,L>,long,op_colm> >
{
const static int value = column_matrix;
};
template <typename T, long NR, long NC, typename MM, typename L>
struct matrix_type_id<matrix_scalar_binary_exp<matrix<T,NR,NC,MM,L>,long,op_rowm> >
{
const static int value = row_matrix;
};
template <typename T, long NR, long NC, typename MM, typename L>
struct matrix_type_id<matrix_sub_exp<matrix<T,NR,NC,MM,L> > >
{
const static int value = general_matrix;
};
// ------------------------------------------------------------------------------------
template <typename T, typename U>
struct same_matrix
{
const static int T_id = matrix_type_id<T>::value;
const static int U_id = matrix_type_id<U>::value;
// The check for unknown_matrix is here so that we can be sure that matrix types
// other than the ones specifically enumerated above never get pushed into
// any of the BLAS bindings. So saying they are never the same as anything
// else prevents them from matching any of the BLAS bindings.
const static bool value = (T_id == U_id) && (T_id != unknown_matrix);
};
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
// This template struct is used to tell us if two matrix expressions both contain the same // This template struct is used to tell us if two matrix expressions both contain the same
......
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
#include "cblas.h" #include "cblas.h"
//#include <iostream> #include <iostream>
//using namespace std; using namespace std;
namespace dlib namespace dlib
{ {
...@@ -206,11 +206,43 @@ namespace dlib ...@@ -206,11 +206,43 @@ namespace dlib
template <typename T, long NR, long NC, typename MM> template <typename T, long NR, long NC, typename MM>
int get_ld (const matrix<T,NR,NC,MM,column_major_layout>& m) { return m.nr(); } int get_ld (const matrix<T,NR,NC,MM,column_major_layout>& m) { return m.nr(); }
template <typename T, long NR, long NC, typename MM>
int get_ld (const matrix_sub_exp<matrix<T,NR,NC,MM,row_major_layout> >& m) { return m.m.nc(); }
template <typename T, long NR, long NC, typename MM>
int get_ld (const matrix_sub_exp<matrix<T,NR,NC,MM,column_major_layout> >& m) { return m.m.nr(); }
// -------- // --------
template <typename T, long NR, long NC, typename MM, typename L> template <typename T, long NR, long NC, typename MM, typename L>
int get_inc (const matrix<T,NR,NC,MM,L>& ) { return 1; } int get_inc (const matrix<T,NR,NC,MM,L>& ) { return 1; }
template <typename T, long NR, long NC, typename MM>
int get_inc(const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,row_major_layout>,long,op_colm>& m)
{
return m.m.nc();
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,row_major_layout>,long,op_rowm>& m)
{
return 1;
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,column_major_layout>,long,op_colm>& m)
{
return 1;
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,column_major_layout>,long,op_rowm>& m)
{
return m.m.nr();
}
// -------- // --------
template <typename T, long NR, long NC, typename MM, typename L> template <typename T, long NR, long NC, typename MM, typename L>
...@@ -219,6 +251,15 @@ namespace dlib ...@@ -219,6 +251,15 @@ namespace dlib
template <typename T, long NR, long NC, typename MM, typename L> template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (matrix<T,NR,NC,MM,L>& m) { return &m(0,0); } T* get_ptr (matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
template <typename T, long NR, long NC, typename MM, typename L>
const T* get_ptr (const matrix_sub_exp<matrix<T,NR,NC,MM,L> >& m) { return &m.m(m.r_,m.c_); }
template <typename T, long NR, long NC, typename MM, typename L>
const T* get_ptr (const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,L>,long,op_colm>& m) { return &m.m(0,m.s); }
template <typename T, long NR, long NC, typename MM, typename L>
const T* get_ptr (const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,L>,long,op_rowm>& m) { return &m.m(m.s,0); }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -375,7 +375,6 @@ namespace dlib ...@@ -375,7 +375,6 @@ namespace dlib
long nc ( long nc (
) const { return OP::nc(m); } ) const { return OP::nc(m); }
private:
const M& m; const M& m;
const S s; const S s;
...@@ -1431,7 +1430,6 @@ namespace dlib ...@@ -1431,7 +1430,6 @@ namespace dlib
long nc ( long nc (
) const { return nc_; } ) const { return nc_; }
private:
const M& m; const M& m;
const long r_, c_, nr_, nc_; const long r_, c_, nr_, nc_;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment