Commit 887f979b authored by Davis King's avatar Davis King

Added stuff to the BLAS bindings so that they can now bind

to statements with matrix sub expressions as the destination
matrix.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402775
parent bfa4b3f2
This diff is collapsed.
......@@ -213,6 +213,11 @@ namespace dlib
template <typename T, long NR, long NC, typename MM>
int get_ld (const matrix_sub_exp<matrix<T,NR,NC,MM,column_major_layout> >& m) { return m.m.nr(); }
template <typename T, long NR, long NC, typename MM>
int get_ld (const assignable_sub_matrix<T,NR,NC,MM,row_major_layout>& m) { return m.m.nc(); }
template <typename T, long NR, long NC, typename MM>
int get_ld (const assignable_sub_matrix<T,NR,NC,MM,column_major_layout>& m) { return m.m.nr(); }
// --------
......@@ -243,6 +248,31 @@ namespace dlib
return m.m.nr();
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const assignable_row_matrix<T,NR,NC,MM,row_major_layout>& m)
{
return 1;
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const assignable_row_matrix<T,NR,NC,MM,column_major_layout>& m)
{
return m.m.nr();
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const assignable_col_matrix<T,NR,NC,MM,row_major_layout>& m)
{
return m.m.nc();
}
template <typename T, long NR, long NC, typename MM>
int get_inc(const assignable_col_matrix<T,NR,NC,MM,column_major_layout>& m)
{
return 1;
}
// --------
template <typename T, long NR, long NC, typename MM, typename L>
......@@ -260,6 +290,16 @@ namespace dlib
template <typename T, long NR, long NC, typename MM, typename L>
const T* get_ptr (const matrix_scalar_binary_exp<matrix<T,NR,NC,MM,L>,long,op_rowm>& m) { return &m.m(m.s,0); }
template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (assignable_col_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (assignable_row_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (assignable_sub_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......@@ -282,6 +322,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(m*m)
{
//cout << "BLAS: m*m" << endl;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
......
......@@ -354,6 +354,10 @@ namespace dlib
class assignable_sub_matrix
{
public:
typedef T type;
typedef l layout_type;
typedef matrix<T,NR,NC,mm,l> matrix_type;
assignable_sub_matrix(
matrix<T,NR,NC,mm,l>& m_,
const rectangle& rect_
......@@ -367,6 +371,14 @@ namespace dlib
return m(r+rect.top(),c+rect.left());
}
const T& operator() (
long r,
long c
) const
{
return m(r+rect.top(),c+rect.left());
}
long nr() const { return rect.height(); }
long nc() const { return rect.width(); }
......@@ -413,7 +425,6 @@ namespace dlib
return *this;
}
private:
matrix<T,NR,NC,mm,l>& m;
const rectangle rect;
......@@ -615,6 +626,10 @@ namespace dlib
class assignable_col_matrix
{
public:
typedef T type;
typedef l layout_type;
typedef matrix<T,NR,NC,mm,l> matrix_type;
assignable_col_matrix(
matrix<T,NR,NC,mm,l>& m_,
const long col_
......@@ -628,6 +643,14 @@ namespace dlib
return m(r,col);
}
const T& operator() (
long r,
long c
) const
{
return m(r,col);
}
long nr() const { return m.nr(); }
long nc() const { return 1; }
......@@ -670,7 +693,6 @@ namespace dlib
return *this;
}
private:
matrix<T,NR,NC,mm,l>& m;
const long col;
......@@ -702,6 +724,10 @@ namespace dlib
class assignable_row_matrix
{
public:
typedef T type;
typedef l layout_type;
typedef matrix<T,NR,NC,mm,l> matrix_type;
assignable_row_matrix(
matrix<T,NR,NC,mm,l>& m_,
const long row_
......@@ -716,6 +742,14 @@ namespace dlib
return m(row,c);
}
const T& operator() (
long r,
long c
) const
{
return m(row,c);
}
long nr() const { return 1; }
long nc() const { return m.nc(); }
......@@ -759,7 +793,6 @@ namespace dlib
return *this;
}
private:
matrix<T,NR,NC,mm,l>& m;
const long row;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment