Commit 887f979b authored by Davis King's avatar Davis King

Added stuff to the BLAS bindings so that they can now bind

to statements with matrix sub expressions as the destination
matrix.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402775
parent bfa4b3f2
This diff is collapsed.
...@@ -213,6 +213,11 @@ namespace dlib ...@@ -213,6 +213,11 @@ namespace dlib
template <typename T, long NR, long NC, typename MM> template <typename T, long NR, long NC, typename MM>
int get_ld (const matrix_sub_exp<matrix<T,NR,NC,MM,column_major_layout> >& m) { return m.m.nr(); } int get_ld (const matrix_sub_exp<matrix<T,NR,NC,MM,column_major_layout> >& m) { return m.m.nr(); }
template <typename T, long NR, long NC, typename MM>
int get_ld (const assignable_sub_matrix<T,NR,NC,MM,row_major_layout>& m) { return m.m.nc(); }
template <typename T, long NR, long NC, typename MM>
int get_ld (const assignable_sub_matrix<T,NR,NC,MM,column_major_layout>& m) { return m.m.nr(); }
// -------- // --------
...@@ -243,6 +248,31 @@ namespace dlib ...@@ -243,6 +248,31 @@ namespace dlib
return m.m.nr(); return m.m.nr();
} }
template <typename T, long NR, long NC, typename MM>
int get_inc(const assignable_row_matrix<T,NR,NC,MM,row_major_layout>& m)
{
return 1;
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const assignable_row_matrix<T,NR,NC,MM,column_major_layout>& m)
{
return m.m.nr();
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const assignable_col_matrix<T,NR,NC,MM,row_major_layout>& m)
{
return m.m.nc();
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const assignable_col_matrix<T,NR,NC,MM,column_major_layout>& m)
{
return 1;
}
// -------- // --------
template <typename T, long NR, long NC, typename MM, typename L> template <typename T, long NR, long NC, typename MM, typename L>
...@@ -260,6 +290,16 @@ namespace dlib ...@@ -260,6 +290,16 @@ namespace dlib
template <typename T, long NR, long NC, typename MM, typename L> template <typename T, long NR, long NC, typename MM, typename L>
const T* get_ptr (const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,L>,long,op_rowm>& m) { return &m.m(m.s,0); } const T* get_ptr (const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,L>,long,op_rowm>& m) { return &m.m(m.s,0); }
template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (assignable_col_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (assignable_row_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (assignable_sub_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -282,6 +322,7 @@ namespace dlib ...@@ -282,6 +322,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(m*m) DLIB_ADD_BLAS_BINDING(m*m)
{ {
//cout << "BLAS: m*m" << endl;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans; const CBLAS_TRANSPOSE TransB = CblasNoTrans;
......
...@@ -354,6 +354,10 @@ namespace dlib ...@@ -354,6 +354,10 @@ namespace dlib
class assignable_sub_matrix class assignable_sub_matrix
{ {
public: public:
typedef T type;
typedef l layout_type;
typedef matrix<T,NR,NC,mm,l> matrix_type;
assignable_sub_matrix( assignable_sub_matrix(
matrix<T,NR,NC,mm,l>& m_, matrix<T,NR,NC,mm,l>& m_,
const rectangle& rect_ const rectangle& rect_
...@@ -367,6 +371,14 @@ namespace dlib ...@@ -367,6 +371,14 @@ namespace dlib
return m(r+rect.top(),c+rect.left()); return m(r+rect.top(),c+rect.left());
} }
const T& operator() (
long r,
long c
) const
{
return m(r+rect.top(),c+rect.left());
}
long nr() const { return rect.height(); } long nr() const { return rect.height(); }
long nc() const { return rect.width(); } long nc() const { return rect.width(); }
...@@ -413,7 +425,6 @@ namespace dlib ...@@ -413,7 +425,6 @@ namespace dlib
return *this; return *this;
} }
private:
matrix<T,NR,NC,mm,l>& m; matrix<T,NR,NC,mm,l>& m;
const rectangle rect; const rectangle rect;
...@@ -615,6 +626,10 @@ namespace dlib ...@@ -615,6 +626,10 @@ namespace dlib
class assignable_col_matrix class assignable_col_matrix
{ {
public: public:
typedef T type;
typedef l layout_type;
typedef matrix<T,NR,NC,mm,l> matrix_type;
assignable_col_matrix( assignable_col_matrix(
matrix<T,NR,NC,mm,l>& m_, matrix<T,NR,NC,mm,l>& m_,
const long col_ const long col_
...@@ -628,6 +643,14 @@ namespace dlib ...@@ -628,6 +643,14 @@ namespace dlib
return m(r,col); return m(r,col);
} }
const T& operator() (
long r,
long c
) const
{
return m(r,col);
}
long nr() const { return m.nr(); } long nr() const { return m.nr(); }
long nc() const { return 1; } long nc() const { return 1; }
...@@ -670,7 +693,6 @@ namespace dlib ...@@ -670,7 +693,6 @@ namespace dlib
return *this; return *this;
} }
private:
matrix<T,NR,NC,mm,l>& m; matrix<T,NR,NC,mm,l>& m;
const long col; const long col;
...@@ -702,6 +724,10 @@ namespace dlib ...@@ -702,6 +724,10 @@ namespace dlib
class assignable_row_matrix class assignable_row_matrix
{ {
public: public:
typedef T type;
typedef l layout_type;
typedef matrix<T,NR,NC,mm,l> matrix_type;
assignable_row_matrix( assignable_row_matrix(
matrix<T,NR,NC,mm,l>& m_, matrix<T,NR,NC,mm,l>& m_,
const long row_ const long row_
...@@ -716,6 +742,14 @@ namespace dlib ...@@ -716,6 +742,14 @@ namespace dlib
return m(row,c); return m(row,c);
} }
const T& operator() (
long r,
long c
) const
{
return m(row,c);
}
long nr() const { return 1; } long nr() const { return 1; }
long nc() const { return m.nc(); } long nc() const { return m.nc(); }
...@@ -759,7 +793,6 @@ namespace dlib ...@@ -759,7 +793,6 @@ namespace dlib
return *this; return *this;
} }
private:
matrix<T,NR,NC,mm,l>& m; matrix<T,NR,NC,mm,l>& m;
const long row; const long row;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment