Commit 890abf73 authored by Davis King's avatar Davis King

Added set_ptrm() along with all the aliasing detection updates necessary to

allow assignment statements with mixed dlib::matrix and mat(T*) expressions
that might alias each other.  Also updated BLAS bindings to bind to set_ptrm()
assignments.
parent c45abebb
......@@ -957,6 +957,11 @@ namespace dlib
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <typename T>
class op_pointer_to_mat;
template <typename T>
class op_pointer_to_col_vect;
template <
typename T,
long num_rows,
......@@ -1549,6 +1554,13 @@ namespace dlib
const matrix_exp<U>&
) const { return false; }
// These two aliases() routines are defined in matrix_mat.h
bool aliases (
const matrix_exp<matrix_op<op_pointer_to_mat<T> > >& item
) const;
bool aliases (
const matrix_exp<matrix_op<op_pointer_to_col_vect<T> > >& item
) const;
iterator begin()
{
......
......@@ -724,6 +724,29 @@ namespace dlib
}
}
// ------------------------------------------------------------------------------------
template <
typename T,
typename src_exp
>
void matrix_assign_blas (
assignable_ptr_matrix<T>& dest,
const src_exp& src
)
{
if (src.aliases(mat(dest.ptr,dest.height,dest.width)))
{
matrix<T> temp(dest.nr(),dest.nc());
matrix_assign_blas_proxy(temp,src,1,false, false);
matrix_assign_default(dest,temp);
}
else
{
matrix_assign_blas_proxy(dest,src,1,false, false);
}
}
// ------------------------------------------------------------------------------------
template <
......@@ -888,6 +911,25 @@ namespace dlib
blas_bindings::matrix_assign_blas(dest,src);
}
// ----------------------------------------------------------------------------------------
template <
typename T,
typename src_exp
>
inline typename enable_if_c<(is_same_type<T,float>::value ||
is_same_type<T,double>::value ||
is_same_type<T,std::complex<float> >::value ||
is_same_type<T,std::complex<double> >::value) &&
blas_bindings::has_matrix_multiply<src_exp>::value
>::type matrix_assign_big (
assignable_ptr_matrix<T>& dest,
const src_exp& src
)
{
blas_bindings::matrix_assign_blas(dest,src);
}
// ----------------------------------------------------------------------------------------
template <
......
......@@ -408,6 +408,9 @@ namespace dlib
template <typename T, long NR, long NC, typename MM>
int get_ld (const assignable_sub_matrix<T,NR,NC,MM,column_major_layout>& m) { return m.m.nr(); }
template <typename T>
int get_ld (const assignable_ptr_matrix<T>& m) { return m.nc(); }
template <typename T, typename MM>
int get_ld (const matrix_op<op_array2d_to_mat<array2d<T,MM> > >& m) { return m.nc(); }
template <typename T, typename MM>
......@@ -480,6 +483,9 @@ namespace dlib
return 1;
}
template <typename T>
int get_inc (const assignable_ptr_matrix<T>& ) { return 1; }
template <typename T, long NR, long NC, typename MM>
int get_inc(const matrix_op<op_colm<matrix<T,NR,NC,MM,row_major_layout> > >& m)
{
......@@ -589,6 +595,9 @@ namespace dlib
template <typename T, long NR, long NC, typename MM, typename L>
T* get_ptr (assignable_sub_matrix<T,NR,NC,MM,L>& m) { return &m(0,0); }
template <typename T>
T* get_ptr (assignable_ptr_matrix<T>& m) { return m.ptr; }
template <typename T, typename MM>
const T* get_ptr (const matrix_op<op_array2d_to_mat<array2d<T,MM> > >& m) { return &m.op.array[0][0]; }
template <typename T, typename MM>
......
......@@ -239,7 +239,10 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <typename T>
struct op_pointer_to_col_vect : does_not_alias
struct op_pointer_to_mat;
template <typename T>
struct op_pointer_to_col_vect
{
op_pointer_to_col_vect(
const T* ptr_,
......@@ -261,6 +264,31 @@ namespace dlib
long nr () const { return size; }
long nc () const { return 1; }
template <typename U> bool aliases ( const matrix_exp<U>& ) const { return false; }
template <typename U> bool destructively_aliases ( const matrix_exp<U>& ) const { return false; }
template <long num_rows, long num_cols, typename mem_manager, typename layout>
bool aliases (
const matrix_exp<matrix<T,num_rows,num_cols, mem_manager,layout> >& item
) const
{
if (item.size() == 0)
return false;
else
return (ptr == &item(0,0));
}
inline bool aliases (
const matrix_exp<matrix_op<op_pointer_to_mat<T> > >& item
) const;
bool aliases (
const matrix_exp<matrix_op<op_pointer_to_col_vect<T> > >& item
) const
{
return item.ref().op.ptr == ptr;
}
};
// ----------------------------------------------------------------------------------------
......@@ -285,7 +313,7 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <typename T>
struct op_pointer_to_mat : does_not_alias
struct op_pointer_to_mat
{
op_pointer_to_mat(
const T* ptr_,
......@@ -309,8 +337,67 @@ namespace dlib
long nr () const { return rows; }
long nc () const { return cols; }
template <typename U> bool aliases ( const matrix_exp<U>& ) const { return false; }
template <typename U> bool destructively_aliases ( const matrix_exp<U>& ) const { return false; }
template <long num_rows, long num_cols, typename mem_manager, typename layout>
bool aliases (
const matrix_exp<matrix<T,num_rows,num_cols, mem_manager,layout> >& item
) const
{
if (item.size() == 0)
return false;
else
return (ptr == &item(0,0));
}
bool aliases (
const matrix_exp<matrix_op<op_pointer_to_mat<T> > >& item
) const
{
return item.ref().op.ptr == ptr;
}
bool aliases (
const matrix_exp<matrix_op<op_pointer_to_col_vect<T> > >& item
) const
{
return item.ref().op.ptr == ptr;
}
};
template <typename T>
bool op_pointer_to_col_vect<T>::
aliases (
const matrix_exp<matrix_op<op_pointer_to_mat<T> > >& item
) const
{
return item.ref().op.ptr == ptr;
}
template <typename T, long NR, long NC, typename MM, typename L>
bool matrix<T,NR,NC,MM,L>::aliases (
const matrix_exp<matrix_op<op_pointer_to_mat<T> > >& item
) const
{
if (size() != 0)
return item.ref().op.ptr == &data(0,0);
else
return false;
}
template <typename T, long NR, long NC, typename MM, typename L>
bool matrix<T,NR,NC,MM,L>::aliases (
const matrix_exp<matrix_op<op_pointer_to_col_vect<T> > >& item
) const
{
if (size() != 0)
return item.ref().op.ptr == &data(0,0);
else
return false;
}
// ----------------------------------------------------------------------------------------
template <
......
......@@ -8,6 +8,7 @@
#include "matrix.h"
#include "../geometry/rectangle.h"
#include "matrix_expressions.h"
#include "matrix_mat.h"
......@@ -522,6 +523,182 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <typename T>
class assignable_ptr_matrix
{
public:
typedef T type;
typedef row_major_layout layout_type;
typedef matrix<T,0,0,default_memory_manager,layout_type> matrix_type;
assignable_ptr_matrix(
T* ptr_,
long nr_,
long nc_
) : ptr(ptr_), height(nr_), width(nc_){}
T& operator() (
long r,
long c
)
{
return ptr[r*width + c];
}
const T& operator() (
long r,
long c
) const
{
return ptr[r*width + c];
}
long nr() const { return height; }
long nc() const { return width; }
template <typename EXP>
assignable_ptr_matrix& operator= (
const matrix_exp<EXP>& exp
)
{
DLIB_ASSERT( exp.nr() == height && exp.nc() == width,
"\tassignable_matrix_expression set_ptrm()"
<< "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
<< "\n\texp.nr() (source matrix): " << exp.nr()
<< "\n\texp.nc() (source matrix): " << exp.nc()
<< "\n\twidth (target matrix): " << width
<< "\n\theight (target matrix): " << height
);
if (exp.destructively_aliases(mat(ptr,height,width)) == false)
{
matrix_assign(*this, exp);
}
else
{
// make a temporary copy of the matrix we are going to assign to ptr to
// avoid aliasing issues during the copy
this->operator=(tmp(exp));
}
return *this;
}
template <typename EXP>
assignable_ptr_matrix& operator+= (
const matrix_exp<EXP>& exp
)
{
DLIB_ASSERT( exp.nr() == height && exp.nc() == width,
"\tassignable_matrix_expression set_ptrm()"
<< "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
<< "\n\texp.nr() (source matrix): " << exp.nr()
<< "\n\texp.nc() (source matrix): " << exp.nc()
<< "\n\twidth (target matrix): " << width
<< "\n\theight (target matrix): " << height
);
if (exp.destructively_aliases(mat(ptr,height,width)) == false)
{
matrix_assign(*this, mat(ptr,height,width)+exp);
}
else
{
// make a temporary copy of the matrix we are going to assign to ptr to
// avoid aliasing issues during the copy
this->operator+=(tmp(exp));
}
return *this;
}
template <typename EXP>
assignable_ptr_matrix& operator-= (
const matrix_exp<EXP>& exp
)
{
DLIB_ASSERT( exp.nr() == height && exp.nc() == width,
"\tassignable_matrix_expression set_ptrm()"
<< "\n\tYou have tried to assign to this object using a matrix that isn't the right size"
<< "\n\texp.nr() (source matrix): " << exp.nr()
<< "\n\texp.nc() (source matrix): " << exp.nc()
<< "\n\twidth (target matrix): " << width
<< "\n\theight (target matrix): " << height
);
if (exp.destructively_aliases(mat(ptr,height,width)) == false)
{
matrix_assign(*this, mat(ptr,height,width)-exp);
}
else
{
// make a temporary copy of the matrix we are going to assign to ptr to
// avoid aliasing issues during the copy
this->operator-=(tmp(exp));
}
return *this;
}
assignable_ptr_matrix& operator= (
const T& value
)
{
const long size = width*height;
for (long i = 0; i < size; ++i)
ptr[i] = value;
return *this;
}
assignable_ptr_matrix& operator+= (
const T& value
)
{
const long size = width*height;
for (long i = 0; i < size; ++i)
ptr[i] += value;
return *this;
}
assignable_ptr_matrix& operator-= (
const T& value
)
{
const long size = width*height;
for (long i = 0; i < size; ++i)
ptr[i] -= value;
return *this;
}
T* ptr;
const long height;
const long width;
};
template <typename T>
assignable_ptr_matrix<T> set_ptrm (
T* ptr,
long nr,
long nc = 1
)
{
DLIB_ASSERT(nr >= 0 && nc >= 0,
"\t assignable_matrix_expression set_ptrm(T* ptr, long nr, long nc)"
<< "\n\t The dimensions can't be negative."
<< "\n\t nr: " << nr
<< "\n\t nc: " << nc
);
return assignable_ptr_matrix<T>(ptr,nr,nc);
}
// ----------------------------------------------------------------------------------------
template <typename T, long NR, long NC, typename mm, typename l>
class assignable_sub_matrix
......
......@@ -297,6 +297,35 @@ namespace dlib
R(r,c) == m(r,cols(c))
!*/
// ----------------------------------------------------------------------------------------
template <typename T>
assignable_matrix_expression set_ptrm (
T* ptr,
long nr,
long nc = 1
);
/*!
requires
- ptr == a pointer to nr*nc elements of type T
- nr >= 0
- nc >= 0
ensures
- statements of the following form:
- set_ptrm(ptr,nr,nc) = some_matrix;
result in it being the case that:
- mat(ptr,nr,nc) == some_matrix.
- statements of the following form:
- set_ptrm(ptr,nr,nc) = scalar_value;
result in it being the case that:
- mat(ptr,nr,nc) == uniform_matrix<matrix::type>(nr,nc,scalar_value).
- In addition to the normal assignment statements using the = symbol, you may
also use the usual += and -= versions of the assignment operator. In these
cases, they have their usual effect.
!*/
// ----------------------------------------------------------------------------------------
assignable_matrix_expression set_subm (
......
......@@ -274,6 +274,32 @@ namespace
DLIB_TEST(counter_gemm() == 1);
}
{
using namespace dlib;
using namespace dlib::blas_bindings;
array2d<double> a(100,100);
array2d<double> b(100,100);
matrix<double> aa(100,100);
matrix<double> bb(100,100);
matrix<double> c;
counter_gemm() = 0;
c = mat(&a[0][0],100,100)*mat(&b[0][0],100,100);
DLIB_TEST(counter_gemm() == 1);
set_ptrm(&c(0,0),100,100) = mat(&a[0][0],100,100)*mat(&b[0][0],100,100);
DLIB_TEST(counter_gemm() == 2);
set_ptrm(&c(0,0),100,100) = aa*bb;
DLIB_TEST(counter_gemm() == 3);
counter_gemm() = 0;
c = trans(2*mat(&a[0][0],100,100)*mat(&b[0][0],100,100));
DLIB_TEST(counter_gemm() == 1);
set_ptrm(&c(0,0),100,100) = trans(2*mat(&a[0][0],100,100)*mat(&b[0][0],100,100));
DLIB_TEST(counter_gemm() == 2);
set_ptrm(&c(0,0),100,100) = trans(2*mat(a)*mat(b));
DLIB_TEST(counter_gemm() == 3);
}
print_spinner();
}
};
......
......@@ -142,13 +142,38 @@ namespace
DLIB_TEST(counter_axpy() == 1);
DLIB_TEST(max(abs(rv2 - 7)) == 0);
counter_axpy() = 0;
m2 = 1;
m = 1;
m2 = 2*m2 + m*5;
DLIB_TEST(counter_axpy() == 1);
DLIB_TEST(max(abs(m2 - 7)) == 0);
if (is_same_type<typename matrix_type::layout_type, row_major_layout>::value)
{
counter_axpy() = 0;
m2 = 1;
m = 1;
set_ptrm(&m2(0,0),m2.nr(),m2.nc()) = 2*m2 + m*5;
DLIB_TEST(max(abs(m2 - 7)) == 0);
DLIB_TEST(counter_axpy() == 1);
counter_axpy() = 0;
m2 = 1;
m = 1;
set_ptrm(&m2(0,0),m2.nr(),m2.nc()) = 2*mat(&m2(0,0),m2.nr(),m2.nc()) + mat(&m(0,0),m.nr(),m.nc())*5;
DLIB_TEST(max(abs(m2 - 7)) == 0);
DLIB_TEST(counter_axpy() == 1);
counter_axpy() = 0;
m2 = 1;
m = 1;
m2 = 2*mat(&m2(0,0),m2.nr(),m2.nc()) + mat(&m(0,0),m.nr(),m.nc())*5;
DLIB_TEST(max(abs(m2 - 7)) == 0);
DLIB_TEST(counter_axpy() == 1);
}
}
......
......@@ -1319,6 +1319,49 @@ namespace
DLIB_TEST(mm(3) == 4);
}
{
const long n = 5;
matrix<double> m1, m2, m3, truth;
m1 = randm(n,n);
m2 = randm(n,n);
m3 = randm(n,n);
truth = m1*m2;
m3 = mat(&m1(0,0),n,n)*mat(&m2(0,0),n,n);
DLIB_TEST(max(abs(truth-m3)) < 1e-13);
m3 = 0;
set_ptrm(&m3(0,0),n,n) = mat(&m1(0,0),n,n)*mat(&m2(0,0),n,n);
DLIB_TEST(max(abs(truth-m3)) < 1e-13);
set_ptrm(&m3(0,0),n,n) = m1*m2;
DLIB_TEST(max(abs(truth-m3)) < 1e-13);
// now make sure it deals with aliasing correctly.
truth = m1*m2;
m1 = mat(&m1(0,0),n,n)*mat(&m2(0,0),n,n);
DLIB_TEST(max(abs(truth-m1)) < 1e-13);
m1 = randm(n,n);
truth = m1*m2;
set_ptrm(&m1(0,0),n,n) = mat(&m1(0,0),n,n)*mat(&m2(0,0),n,n);
DLIB_TEST(max(abs(truth-m1)) < 1e-13);
m1 = randm(n,n);
truth = m1*m2;
set_ptrm(&m1(0,0),n,n) = m1*m2;
DLIB_TEST(max(abs(truth-m1)) < 1e-13);
m1 = randm(n,n);
truth = m1+m1*m2;
set_ptrm(&m1(0,0),n,n) += m1*m2;
DLIB_TEST(max(abs(truth-m1)) < 1e-13);
m1 = randm(n,n);
truth = m1-m1*m2;
set_ptrm(&m1(0,0),n,n) -= m1*m2;
DLIB_TEST(max(abs(truth-m1)) < 1e-13);
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment