Commit eb823114 authored by Davis King's avatar Davis King

Checking in changes to the matrix object that allow it to

factor expressions containing trans() operators.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402833
parent 0803ff8f
This diff is collapsed.
...@@ -326,8 +326,6 @@ namespace dlib ...@@ -326,8 +326,6 @@ namespace dlib
//cout << "BLAS GEMM: m*m" << endl; //cout << "BLAS GEMM: m*m" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.nr()); const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc()); const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc()); const int K = static_cast<int>(src.lhs.nc());
...@@ -340,7 +338,11 @@ namespace dlib ...@@ -340,7 +338,11 @@ namespace dlib
T* C = get_ptr(dest); T* C = get_ptr(dest);
const int ldc = get_ld(dest); const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); if (transpose == false)
cblas_gemm(Order, CblasNoTrans, CblasNoTrans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
cblas_gemm(Order, CblasTrans, CblasTrans, M, N, K, alpha, B, ldb, A, lda, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
...@@ -364,7 +366,11 @@ namespace dlib ...@@ -364,7 +366,11 @@ namespace dlib
T* C = get_ptr(dest); T* C = get_ptr(dest);
const int ldc = get_ld(dest); const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, B, ldb, A, lda, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
...@@ -388,18 +394,20 @@ namespace dlib ...@@ -388,18 +394,20 @@ namespace dlib
T* C = get_ptr(dest); T* C = get_ptr(dest);
const int ldc = get_ld(dest); const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, B, ldb, A, lda, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(m)*trans(m)) DLIB_ADD_BLAS_BINDING(trans(m)*trans(m))
{ {
//cout << "BLAS GEMM: trans(m)*trans(m)" << endl; cout << "BLAS GEMM: trans(m)*trans(m)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value; const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasTrans;
const int M = static_cast<int>(src.nr()); const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc()); const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc()); const int K = static_cast<int>(src.lhs.nc());
...@@ -412,7 +420,10 @@ namespace dlib ...@@ -412,7 +420,10 @@ namespace dlib
T* C = get_ptr(dest); T* C = get_ptr(dest);
const int ldc = get_ld(dest); const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); if (transpose == false)
cblas_gemm(Order, CblasTrans, CblasTrans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
cblas_gemm(Order, CblasNoTrans, CblasNoTrans, M, N, K, alpha, B, ldb, A, lda, beta, C, ldc);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
...@@ -438,7 +449,11 @@ namespace dlib ...@@ -438,7 +449,11 @@ namespace dlib
T* C = get_ptr(dest); T* C = get_ptr(dest);
const int ldc = get_ld(dest); const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
matrix_assign_default(dest, trans(src), alpha, add_to);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
...@@ -462,7 +477,10 @@ namespace dlib ...@@ -462,7 +477,10 @@ namespace dlib
T* C = get_ptr(dest); T* C = get_ptr(dest);
const int ldc = get_ld(dest); const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
matrix_assign_default(dest, trans(src), alpha, add_to);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
...@@ -486,7 +504,10 @@ namespace dlib ...@@ -486,7 +504,10 @@ namespace dlib
T* C = get_ptr(dest); T* C = get_ptr(dest);
const int ldc = get_ld(dest); const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
matrix_assign_default(dest, trans(src), alpha, add_to);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -795,7 +816,11 @@ namespace dlib ...@@ -795,7 +816,11 @@ namespace dlib
T* A = get_ptr(dest); T* A = get_ptr(dest);
const int lda = get_ld(dest); const int lda = get_ld(dest);
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda); if (transpose == false)
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
...@@ -818,7 +843,10 @@ namespace dlib ...@@ -818,7 +843,10 @@ namespace dlib
T* A = get_ptr(dest); T* A = get_ptr(dest);
const int lda = get_ld(dest); const int lda = get_ld(dest);
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda); if (transpose == false)
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
...@@ -841,7 +869,10 @@ namespace dlib ...@@ -841,7 +869,10 @@ namespace dlib
T* A = get_ptr(dest); T* A = get_ptr(dest);
const int lda = get_ld(dest); const int lda = get_ld(dest);
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda); if (transpose == false)
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
...@@ -864,7 +895,10 @@ namespace dlib ...@@ -864,7 +895,10 @@ namespace dlib
T* A = get_ptr(dest); T* A = get_ptr(dest);
const int lda = get_ld(dest); const int lda = get_ld(dest);
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda); if (transpose == false)
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -891,7 +925,10 @@ namespace dlib ...@@ -891,7 +925,10 @@ namespace dlib
T* A = get_ptr(dest); T* A = get_ptr(dest);
const int lda = get_ld(dest); const int lda = get_ld(dest);
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda); if (transpose == false)
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
...@@ -914,7 +951,10 @@ namespace dlib ...@@ -914,7 +951,10 @@ namespace dlib
T* A = get_ptr(dest); T* A = get_ptr(dest);
const int lda = get_ld(dest); const int lda = get_ld(dest);
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda); if (transpose == false)
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
...@@ -937,7 +977,10 @@ namespace dlib ...@@ -937,7 +977,10 @@ namespace dlib
T* A = get_ptr(dest); T* A = get_ptr(dest);
const int lda = get_ld(dest); const int lda = get_ld(dest);
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda); if (transpose == false)
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// -------------------------------------- // --------------------------------------
...@@ -960,7 +1003,10 @@ namespace dlib ...@@ -960,7 +1003,10 @@ namespace dlib
T* A = get_ptr(dest); T* A = get_ptr(dest);
const int lda = get_ld(dest); const int lda = get_ld(dest);
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda); if (transpose == false)
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING } DLIB_END_BLAS_BINDING
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -26,7 +26,7 @@ namespace ...@@ -26,7 +26,7 @@ namespace
logger dlog("test.matrix3"); logger dlog("test.matrix3");
const double eps_mul = 200000; const double eps_mul = 500000;
template <typename T, typename U> template <typename T, typename U>
void check_equal ( void check_equal (
...@@ -101,7 +101,7 @@ namespace ...@@ -101,7 +101,7 @@ namespace
dlib::rand::float_1a rnd; dlib::rand::float_1a rnd;
matrix<type> a(rows,cols), temp, temp2; matrix<type> a(rows,cols), temp, temp2, temp3;
for (int i = 0; i < 6; ++i) for (int i = 0; i < 6; ++i)
{ {
...@@ -173,51 +173,99 @@ namespace ...@@ -173,51 +173,99 @@ namespace
// GEMM tests // GEMM tests
dlog << LTRACE << "1.1"; dlog << LTRACE << "1.1";
check_equal(tmp(at*a), at*a); check_equal(tmp(at*a), at*a);
check_equal(tmp(trans(at*a)), trans(at*a));
check_equal(tmp(2.4*trans(4*trans(at*a) + at*3*a)), 2.4*trans(4*trans(at*a) + at*3*a));
dlog << LTRACE << "1.2"; dlog << LTRACE << "1.2";
check_equal(tmp(trans(a)*a), trans(a)*a); check_equal(tmp(trans(a)*a), trans(a)*a);
check_equal(tmp(trans(trans(a)*a)), trans(trans(a)*a));
dlog << LTRACE << "1.3"; dlog << LTRACE << "1.3";
check_equal(tmp(at*trans(at)), at*trans(at)); check_equal(tmp(at*trans(at)), at*trans(at));
check_equal(tmp(trans(at*trans(at))), trans(at*trans(at)));
dlog << LTRACE << "1.4"; dlog << LTRACE << "1.4";
check_equal(tmp(trans(at)*trans(a)), a*at); check_equal(tmp(trans(at)*trans(a)), a*at);
check_equal(tmp(trans(trans(at)*trans(a))), trans(a*at));
dlog << LTRACE << "1.5"; dlog << LTRACE << "1.5";
print_spinner(); print_spinner();
c_check_equal(tmp(conj(trans(c_a))*c_a), trans(conj(c_a))*c_a); c_check_equal(tmp(conj(trans(c_a))*c_a), trans(conj(c_a))*c_a);
c_check_equal(tmp(trans(conj(trans(c_a))*c_a)), trans(trans(conj(c_a))*c_a));
dlog << LTRACE << "1.6"; dlog << LTRACE << "1.6";
c_check_equal(tmp(c_at*trans(conj(c_at))), c_at*conj(trans(c_at))); c_check_equal(tmp(c_at*trans(conj(c_at))), c_at*conj(trans(c_at)));
c_check_equal(tmp(trans(c_at*trans(conj(c_at)))), trans(c_at*conj(trans(c_at))));
dlog << LTRACE << "1.7"; dlog << LTRACE << "1.7";
c_check_equal(tmp(conj(trans(c_at))*trans(conj(c_a))), conj(trans(c_at))*trans(conj(c_a))); c_check_equal(tmp(conj(trans(c_at))*trans(conj(c_a))), conj(trans(c_at))*trans(conj(c_a)));
c_check_equal(tmp(trans(conj(trans(c_at))*trans(conj(c_a)))), trans(conj(trans(c_at))*trans(conj(c_a))));
dlog << LTRACE << "1.8"; dlog << LTRACE << "1.8";
check_equal(tmp(a*trans(rowm(a,1))) , a*trans(rowm(a,1))); check_equal(tmp(a*trans(rowm(a,1))) , a*trans(rowm(a,1)));
check_equal(tmp(a*colm(at,1)) , a*colm(at,1)); check_equal(tmp(a*colm(at,1)) , a*colm(at,1));
check_equal(tmp(subm(a,1,1,2,2)*subm(a,1,2,2,2)), subm(a,1,1,2,2)*subm(a,1,2,2,2)); check_equal(tmp(subm(a,1,1,2,2)*subm(a,1,2,2,2)), subm(a,1,1,2,2)*subm(a,1,2,2,2));
temp = at*a; dlog << LTRACE << "1.9";
temp2 = temp; check_equal(tmp(trans(a*trans(rowm(a,1)))) , trans(a*trans(rowm(a,1))));
dlog << LTRACE << "1.10";
check_equal(tmp(trans(a*colm(at,1))) , trans(a*colm(at,1)));
dlog << LTRACE << "1.11";
check_equal(tmp(trans(subm(a,1,1,2,2)*subm(a,1,2,2,2))), trans(subm(a,1,1,2,2)*subm(a,1,2,2,2)));
dlog << LTRACE << "1.12";
temp += 3.5*at*a; {
assign_no_blas(temp2, temp2 + 3.5*at*a); temp = at*a;
check_equal(temp, temp2); temp2 = temp;
temp -= at*3.5*a; temp += 3.5*at*a;
assign_no_blas(temp2, temp2 - at*3.5*a); assign_no_blas(temp2, temp2 + 3.5*at*a);
check_equal(temp, temp2); check_equal(temp, temp2);
temp = temp + 4*at*a; temp -= at*3.5*a;
assign_no_blas(temp2, temp2 + 4*at*a); assign_no_blas(temp2, temp2 - at*3.5*a);
check_equal(temp, temp2); check_equal(temp, temp2);
temp = temp - 2.4*at*a; temp = temp + 4*at*a;
assign_no_blas(temp2, temp2 - 2.4*at*a); assign_no_blas(temp2, temp2 + 4*at*a);
check_equal(temp, temp2); check_equal(temp, temp2);
temp = temp - 2.4*at*a;
assign_no_blas(temp2, temp2 - 2.4*at*a);
check_equal(temp, temp2);
}
dlog << LTRACE << "1.13";
{
temp = trans(at*a);
temp2 = temp;
temp3 = temp;
dlog << LTRACE << "1.14";
temp += trans(3.5*at*a);
assign_no_blas(temp2, temp2 + trans(3.5*at*a));
check_equal(temp, temp2);
dlog << LTRACE << "1.15";
temp -= trans(at*3.5*a);
assign_no_blas(temp2, temp2 - trans(at*3.5*a));
check_equal(temp, temp2);
dlog << LTRACE << "1.16";
temp = trans(temp + 4*at*a);
assign_no_blas(temp3, trans(temp2 + 4*at*a));
check_equal(temp, temp3);
temp2 = temp;
dlog << LTRACE << "1.17";
temp = trans(temp - 2.4*at*a);
assign_no_blas(temp3, trans(temp2 - 2.4*at*a));
check_equal(temp, temp3);
}
dlog << LTRACE << "1.18";
// GEMV tests // GEMV tests
check_equal(tmp(a*cv4), a*cv4); check_equal(tmp(a*cv4), a*cv4);
check_equal(tmp(trans(a*cv4)), trans(a*cv4));
check_equal(tmp(rv3*a), rv3*a); check_equal(tmp(rv3*a), rv3*a);
check_equal(tmp(trans(cv4)*at), trans(cv4)*at); check_equal(tmp(trans(cv4)*at), trans(cv4)*at);
check_equal(tmp(a*trans(rv4)), a*trans(rv4)); check_equal(tmp(a*trans(rv4)), a*trans(rv4));
check_equal(tmp(trans(a*trans(rv4))), trans(a*trans(rv4)));
check_equal(tmp(trans(a)*cv3), trans(a)*cv3); check_equal(tmp(trans(a)*cv3), trans(a)*cv3);
check_equal(tmp(rv4*trans(a)), rv4*trans(a)); check_equal(tmp(rv4*trans(a)), rv4*trans(a));
...@@ -291,21 +339,76 @@ namespace ...@@ -291,21 +339,76 @@ namespace
// Test BLAS GER // Test BLAS GER
temp.set_size(cols,cols); {
set_all_elements(temp,3); temp.set_size(cols,cols);
temp2 = temp; set_all_elements(temp,3);
temp2 = temp;
dlog << LTRACE << "8"; dlog << LTRACE << "8";
temp += cv4*rv4; temp += cv4*rv4;
assign_no_blas(temp2, temp2 + cv4*rv4); assign_no_blas(temp2, temp2 + cv4*rv4);
check_equal(temp, temp2); check_equal(temp, temp2);
dlog << LTRACE << "8.3"; dlog << LTRACE << "8.3";
temp = temp + cv4*rv4; temp = temp + cv4*rv4;
assign_no_blas(temp2, temp2 + cv4*rv4); assign_no_blas(temp2, temp2 + cv4*rv4);
check_equal(temp, temp2); check_equal(temp, temp2);
dlog << LTRACE << "8.9"; dlog << LTRACE << "8.9";
}
{
temp.set_size(cols,cols);
set_all_elements(temp,3);
temp2 = temp;
temp3 = 0;
dlog << LTRACE << "8.10";
temp += trans(cv4*rv4);
assign_no_blas(temp3, temp2 + trans(cv4*rv4));
check_equal(temp, temp3);
temp3 = 0;
dlog << LTRACE << "8.11";
temp2 = temp;
temp = trans(temp + cv4*rv4);
assign_no_blas(temp3, trans(temp2 + cv4*rv4));
check_equal(temp, temp3);
dlog << LTRACE << "8.12";
}
{
matrix<complex<type> > temp, temp2, temp3;
matrix<complex<type>,0,1 > cv4;
matrix<complex<type>,1,0 > rv4;
cv4.set_size(cols);
rv4.set_size(cols);
temp.set_size(cols,cols);
set_all_elements(temp,complex<type>(3,5));
temp(cols-1, cols-4) = 9;
temp2 = temp;
temp3.set_size(cols,cols);
temp3 = 0;
for (long i = 0; i < rv4.size(); ++i)
{
rv4(i) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd));
cv4(i) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd));
}
dlog << LTRACE << "8.13";
temp += trans(cv4*rv4);
assign_no_blas(temp3, temp2 + trans(cv4*rv4));
c_check_equal(temp, temp3);
temp3 = 0;
dlog << LTRACE << "8.14";
temp2 = temp;
temp = trans(temp + cv4*rv4);
assign_no_blas(temp3, trans(temp2 + cv4*rv4));
c_check_equal(temp, temp3);
dlog << LTRACE << "8.15";
}
...@@ -340,6 +443,7 @@ namespace ...@@ -340,6 +443,7 @@ namespace
// Test DOT // Test DOT
check_equal( tmp(rv4*cv4), rv4*cv4); check_equal( tmp(rv4*cv4), rv4*cv4);
check_equal( tmp(trans(rv4*cv4)), trans(rv4*cv4));
check_equal( tmp(trans(cv4)*trans(rv4)), trans(cv4)*trans(rv4)); check_equal( tmp(trans(cv4)*trans(rv4)), trans(cv4)*trans(rv4));
check_equal( tmp(rv4*3.9*cv4), rv4*3.9*cv4); check_equal( tmp(rv4*3.9*cv4), rv4*3.9*cv4);
check_equal( tmp(trans(cv4)*3.9*trans(rv4)), trans(cv4)*3.9*trans(rv4)); check_equal( tmp(trans(cv4)*3.9*trans(rv4)), trans(cv4)*3.9*trans(rv4));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment