Commit a87e2b8e authored by Davis King's avatar Davis King

Added stuff to make matrix sub-expressions get bound into BLAS

calls.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402768
parent d1c92fb8
......@@ -6,6 +6,7 @@
#include "../geometry.h"
#include "matrix.h"
#include "matrix_utilities.h"
#include "matrix_subexp.h"
#include "../enable_if.h"
#include "matrix_assign_fwd.h"
#include "matrix_default_mul.h"
......@@ -57,28 +58,77 @@ namespace dlib
struct has_matrix_multiply<matrix_unary_exp<T,OP> >
{ const static bool value = has_matrix_multiply<T>::value; };
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
template <typename T, typename U>
struct same_matrix
const int unknown_matrix = 0;
const int general_matrix = 1;
const int row_matrix = 2;
const int column_matrix = 3;
// ------------------------------------------------------------------------------------
template <typename T>
struct matrix_type_id
{
const static bool value = false;
const static int value = unknown_matrix;
};
template <typename T1, typename T2, typename L1, typename L2, long NR1, long NC1, long NR2, long NC2, typename MM1, typename MM2 >
struct same_matrix <matrix<T1,NR1,NC1,MM1,L1>, matrix<T2,NR2,NC2,MM2,L2> >
{
/*! These two matrices are the same if they are either:
- both row vectors
- both column vectors
- both general non-vector matrices
!*/
const static bool value = (NR1 == 1 && NR2 == 1) ||
(NC1==1 && NC2==1) ||
(NR1!=1 && NC1!=1 && NR2!=1 && NC2!=1);
template <typename T, long NR, long NC, typename MM, typename L>
struct matrix_type_id<matrix<T,NR,NC,MM,L> >
{
const static int value = general_matrix;
};
template <typename T, long NR, typename MM, typename L>
struct matrix_type_id<matrix<T,NR,1,MM,L> >
{
const static int value = column_matrix;
};
template <typename T, long NC, typename MM, typename L>
struct matrix_type_id<matrix<T,1,NC,MM,L> >
{
const static int value = row_matrix;
};
// ------------------------------------------------------------------------------------
template <typename T, long NR, long NC, typename MM, typename L>
struct matrix_type_id<matrix_scalar_binary_exp<matrix<T,NR,NC,MM,L>,long,op_colm> >
{
const static int value = column_matrix;
};
template <typename T, long NR, long NC, typename MM, typename L>
struct matrix_type_id<matrix_scalar_binary_exp<matrix<T,NR,NC,MM,L>,long,op_rowm> >
{
const static int value = row_matrix;
};
template <typename T, long NR, long NC, typename MM, typename L>
struct matrix_type_id<matrix_sub_exp<matrix<T,NR,NC,MM,L> > >
{
const static int value = general_matrix;
};
// ------------------------------------------------------------------------------------
template <typename T, typename U>
struct same_matrix
{
const static int T_id = matrix_type_id<T>::value;
const static int U_id = matrix_type_id<U>::value;
// The check for unknown_matrix is here so that we can be sure that matrix types
// other than the ones specifically enumerated above never get pushed into
// any of the BLAS bindings. So saying they are never the same as anything
// else prevents them from matching any of the BLAS bindings.
const static bool value = (T_id == U_id) && (T_id != unknown_matrix);
};
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
// This template struct is used to tell us if two matrix expressions both contain the same
......
......@@ -11,8 +11,8 @@
#include "cblas.h"
//#include <iostream>
//using namespace std;
#include <iostream>
using namespace std;
namespace dlib
{
......@@ -206,11 +206,43 @@ namespace dlib
template <typename T, long NR, long NC, typename MM>
int get_ld (const matrix<T,NR,NC,MM,column_major_layout>& m) { return m.nr(); }
template <typename T, long NR, long NC, typename MM>
int get_ld (const matrix_sub_exp<matrix<T,NR,NC,MM,row_major_layout> >& m) { return m.m.nc(); }
template <typename T, long NR, long NC, typename MM>
int get_ld (const matrix_sub_exp<matrix<T,NR,NC,MM,column_major_layout> >& m) { return m.m.nr(); }
// --------
template <typename T, long NR, long NC, typename MM, typename L>
int get_inc (const matrix<T,NR,NC,MM,L>& ) { return 1; }
template <typename T, long NR, long NC, typename MM>
int get_inc(const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,row_major_layout>,long,op_colm>& m)
{
return m.m.nc();
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,row_major_layout>,long,op_rowm>& m)
{
return 1;
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,column_major_layout>,long,op_colm>& m)
{
return 1;
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,column_major_layout>,long,op_rowm>& m)
{
return m.m.nr();
}
// --------
template <typename T, long NR, long NC, typename MM, typename L>
......@@ -219,6 +251,15 @@ namespace dlib
template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
template <typename T, long NR, long NC, typename MM, typename L>
const T* get_ptr (const matrix_sub_exp<matrix<T,NR,NC,MM,L> >& m) { return &m.m(m.r_,m.c_); }
template <typename T, long NR, long NC, typename MM, typename L>
const T* get_ptr (const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,L>,long,op_colm>& m) { return &m.m(0,m.s); }
template <typename T, long NR, long NC, typename MM, typename L>
const T* get_ptr (const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,L>,long,op_rowm>& m) { return &m.m(m.s,0); }
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......
......@@ -375,7 +375,6 @@ namespace dlib
long nc (
) const { return OP::nc(m); }
private:
const M& m;
const S s;
......@@ -1431,7 +1430,6 @@ namespace dlib
long nc (
) const { return nc_; }
private:
const M& m;
const long r_, c_, nr_, nc_;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment