Commit f1c734d6 authored by Davis King's avatar Davis King

Fixed a bug pointed out by Ernesto Tapia that could cause matrix expressions

that involve sub matrix views (e.g. colm) to produce the wrong results when the
BLAS bindings were enabled.
parent f6ad191c
...@@ -423,6 +423,10 @@ namespace dlib ...@@ -423,6 +423,10 @@ namespace dlib
// -------- // --------
// get_inc() returns the offset from one element to another. If an object has a
// non-uniform offset between elements then returns 0 (e.g. a subm() view could
// have a non-uniform offset between elements).
template <typename T, typename MM> template <typename T, typename MM>
int get_inc (const matrix_op<op_array2d_to_mat<array2d<T,MM> > >& ) { return 1; } int get_inc (const matrix_op<op_array2d_to_mat<array2d<T,MM> > >& ) { return 1; }
template <typename T, typename MM> template <typename T, typename MM>
...@@ -439,6 +443,43 @@ namespace dlib ...@@ -439,6 +443,43 @@ namespace dlib
template <typename T, long NR, long NC, typename MM, typename L> template <typename T, long NR, long NC, typename MM, typename L>
int get_inc (const matrix<T,NR,NC,MM,L>& ) { return 1; } int get_inc (const matrix<T,NR,NC,MM,L>& ) { return 1; }
template <typename T, long NR, long NC, typename MM>
int get_inc (const matrix_op<op_subm<matrix<T,NR,NC,MM,row_major_layout> > >& m)
{
// if the sub-view doesn't cover all the columns then it can't have a uniform
// layout.
if (m.nc() < m.op.m.nc())
return 0;
else
return 1;
}
template <typename T, long NR, long NC, typename MM>
int get_inc (const matrix_op<op_subm<matrix<T,NR,NC,MM,column_major_layout> > >& m)
{
if (m.nr() < m.op.m.nr())
return 0;
else
return 1;
}
template <typename T, long NR, long NC, typename MM>
int get_inc (const assignable_sub_matrix<T,NR,NC,MM,row_major_layout>& m)
{
if (m.nc() < m.m.nc())
return 0;
else
return 1;
}
template <typename T, long NR, long NC, typename MM>
int get_inc (const assignable_sub_matrix<T,NR,NC,MM,column_major_layout>& m)
{
if (m.nr() < m.m.nr())
return 0;
else
return 1;
}
template <typename T, long NR, long NC, typename MM> template <typename T, long NR, long NC, typename MM>
int get_inc(const matrix_op<op_colm<matrix<T,NR,NC,MM,row_major_layout> > >& m) int get_inc(const matrix_op<op_colm<matrix<T,NR,NC,MM,row_major_layout> > >& m)
{ {
...@@ -589,18 +630,17 @@ namespace dlib ...@@ -589,18 +630,17 @@ namespace dlib
{ {
if (add_to) if (add_to)
{ {
cblas_axpy(N, alpha, get_ptr(src), 1, get_ptr(dest), 1); if (get_inc(src) && get_inc(dest))
cblas_axpy(N, alpha, get_ptr(src), get_inc(src), get_ptr(dest), get_inc(dest));
else
matrix_assign_default(dest, src, alpha, add_to);
} }
else else
{ {
if (get_ptr(src) == get_ptr(dest)) if (get_ptr(src) == get_ptr(dest))
{
cblas_scal(N, alpha, get_ptr(dest)); cblas_scal(N, alpha, get_ptr(dest));
}
else else
{
matrix_assign_default(dest, src, alpha, add_to); matrix_assign_default(dest, src, alpha, add_to);
}
} }
} }
else else
...@@ -618,18 +658,17 @@ namespace dlib ...@@ -618,18 +658,17 @@ namespace dlib
{ {
if (add_to) if (add_to)
{ {
cblas_axpy(N, alpha, get_ptr(src), 1, get_ptr(dest), 1); if (get_inc(src) && get_inc(dest))
cblas_axpy(N, alpha, get_ptr(src), get_inc(src), get_ptr(dest), get_inc(dest));
else
matrix_assign_default(dest, src, alpha, add_to);
} }
else else
{ {
if (get_ptr(src) == get_ptr(dest)) if (get_ptr(src) == get_ptr(dest))
{
cblas_scal(N, alpha, get_ptr(dest)); cblas_scal(N, alpha, get_ptr(dest));
}
else else
{
matrix_assign_default(dest, src, alpha, add_to); matrix_assign_default(dest, src, alpha, add_to);
}
} }
} }
else else
...@@ -647,18 +686,17 @@ namespace dlib ...@@ -647,18 +686,17 @@ namespace dlib
{ {
if (add_to) if (add_to)
{ {
cblas_axpy(N, alpha, get_ptr(src), 1, get_ptr(dest), 1); if (get_inc(src) && get_inc(dest))
cblas_axpy(N, alpha, get_ptr(src), get_inc(src), get_ptr(dest), get_inc(dest));
else
matrix_assign_default(dest, src, alpha, add_to);
} }
else else
{ {
if (get_ptr(src) == get_ptr(dest)) if (get_ptr(src) == get_ptr(dest))
{
cblas_scal(N, alpha, get_ptr(dest)); cblas_scal(N, alpha, get_ptr(dest));
}
else else
{
matrix_assign_default(dest, src, alpha, add_to); matrix_assign_default(dest, src, alpha, add_to);
}
} }
} }
else else
......
...@@ -1074,6 +1074,40 @@ namespace ...@@ -1074,6 +1074,40 @@ namespace
} }
void test_axpy()
{
const int n = 4;
matrix<double> B = dlib::randm(n,n);
matrix<double> g = dlib::uniform_matrix<double>(n,1,0.0);
const double tau = 1;
matrix<double> p = g + tau*dlib::colm(B,0);
matrix<double> q = dlib::colm(B,0);
DLIB_TEST(length(p-q) < 1e-14);
p = tau*dlib::colm(B,0);
q = dlib::colm(B,0);
DLIB_TEST(length(p-q) < 1e-14);
g = dlib::uniform_matrix<double>(n,n,0.0);
p = g + tau*B;
DLIB_TEST(length(p-B) < 1e-14);
p = g + tau*subm(B,get_rect(B));
DLIB_TEST(length(p-B) < 1e-14);
g = dlib::uniform_matrix<double>(2,2,0.0);
p = g + tau*subm(B,1,1,2,2);
DLIB_TEST(length(p-subm(B,1,1,2,2)) < 1e-14);
set_subm(p,0,0,2,2) = g + tau*subm(B,1,1,2,2);
DLIB_TEST(length(p-subm(B,1,1,2,2)) < 1e-14);
}
class matrix_tester : public tester class matrix_tester : public tester
...@@ -1088,6 +1122,7 @@ namespace ...@@ -1088,6 +1122,7 @@ namespace
void perform_test ( void perform_test (
) )
{ {
test_axpy();
test_matrix_IO(); test_matrix_IO();
matrix_test(); matrix_test();
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment