Commit eb823114 authored by Davis King's avatar Davis King

Checking in changes to the matrix object that allow it to

factor expressions containing trans() operators.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402833
parent 0803ff8f
This diff is collapsed.
......@@ -326,8 +326,6 @@ namespace dlib
//cout << "BLAS GEMM: m*m" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
......@@ -340,7 +338,11 @@ namespace dlib
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
if (transpose == false)
cblas_gemm(Order, CblasNoTrans, CblasNoTrans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
cblas_gemm(Order, CblasTrans, CblasTrans, M, N, K, alpha, B, ldb, A, lda, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -364,7 +366,11 @@ namespace dlib
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, B, ldb, A, lda, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -388,18 +394,20 @@ namespace dlib
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, B, ldb, A, lda, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(m)*trans(m))
{
//cout << "BLAS GEMM: trans(m)*trans(m)" << endl;
cout << "BLAS GEMM: trans(m)*trans(m)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
......@@ -412,7 +420,10 @@ namespace dlib
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
if (transpose == false)
cblas_gemm(Order, CblasTrans, CblasTrans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
cblas_gemm(Order, CblasNoTrans, CblasNoTrans, M, N, K, alpha, B, ldb, A, lda, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -438,7 +449,11 @@ namespace dlib
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
matrix_assign_default(dest, trans(src), alpha, add_to);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -462,7 +477,10 @@ namespace dlib
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
matrix_assign_default(dest, trans(src), alpha, add_to);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -486,7 +504,10 @@ namespace dlib
T* C = get_ptr(dest);
const int ldc = get_ld(dest);
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
if (transpose == false)
cblas_gemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
else
matrix_assign_default(dest, trans(src), alpha, add_to);
} DLIB_END_BLAS_BINDING
// ----------------------------------------------------------------------------------------
......@@ -795,7 +816,11 @@ namespace dlib
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
if (transpose == false)
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -818,7 +843,10 @@ namespace dlib
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
if (transpose == false)
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -841,7 +869,10 @@ namespace dlib
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
if (transpose == false)
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -864,7 +895,10 @@ namespace dlib
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
if (transpose == false)
cblas_ger(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_ger(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
// ----------------------------------------------------------------------------------------
......@@ -891,7 +925,10 @@ namespace dlib
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
if (transpose == false)
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -914,7 +951,10 @@ namespace dlib
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
if (transpose == false)
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -937,7 +977,10 @@ namespace dlib
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
if (transpose == false)
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
// --------------------------------------
......@@ -960,7 +1003,10 @@ namespace dlib
T* A = get_ptr(dest);
const int lda = get_ld(dest);
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
if (transpose == false)
cblas_gerc(Order, M, N, alpha, X, incX, Y, incY, A, lda);
else
cblas_gerc(Order, M, N, alpha, Y, incY, X, incX, A, lda);
} DLIB_END_BLAS_BINDING
// ----------------------------------------------------------------------------------------
......
......@@ -26,7 +26,7 @@ namespace
logger dlog("test.matrix3");
const double eps_mul = 200000;
const double eps_mul = 500000;
template <typename T, typename U>
void check_equal (
......@@ -101,7 +101,7 @@ namespace
dlib::rand::float_1a rnd;
matrix<type> a(rows,cols), temp, temp2;
matrix<type> a(rows,cols), temp, temp2, temp3;
for (int i = 0; i < 6; ++i)
{
......@@ -173,51 +173,99 @@ namespace
// GEMM tests
dlog << LTRACE << "1.1";
check_equal(tmp(at*a), at*a);
check_equal(tmp(trans(at*a)), trans(at*a));
check_equal(tmp(2.4*trans(4*trans(at*a) + at*3*a)), 2.4*trans(4*trans(at*a) + at*3*a));
dlog << LTRACE << "1.2";
check_equal(tmp(trans(a)*a), trans(a)*a);
check_equal(tmp(trans(trans(a)*a)), trans(trans(a)*a));
dlog << LTRACE << "1.3";
check_equal(tmp(at*trans(at)), at*trans(at));
check_equal(tmp(trans(at*trans(at))), trans(at*trans(at)));
dlog << LTRACE << "1.4";
check_equal(tmp(trans(at)*trans(a)), a*at);
check_equal(tmp(trans(trans(at)*trans(a))), trans(a*at));
dlog << LTRACE << "1.5";
print_spinner();
c_check_equal(tmp(conj(trans(c_a))*c_a), trans(conj(c_a))*c_a);
c_check_equal(tmp(trans(conj(trans(c_a))*c_a)), trans(trans(conj(c_a))*c_a));
dlog << LTRACE << "1.6";
c_check_equal(tmp(c_at*trans(conj(c_at))), c_at*conj(trans(c_at)));
c_check_equal(tmp(trans(c_at*trans(conj(c_at)))), trans(c_at*conj(trans(c_at))));
dlog << LTRACE << "1.7";
c_check_equal(tmp(conj(trans(c_at))*trans(conj(c_a))), conj(trans(c_at))*trans(conj(c_a)));
c_check_equal(tmp(trans(conj(trans(c_at))*trans(conj(c_a)))), trans(conj(trans(c_at))*trans(conj(c_a))));
dlog << LTRACE << "1.8";
check_equal(tmp(a*trans(rowm(a,1))) , a*trans(rowm(a,1)));
check_equal(tmp(a*colm(at,1)) , a*colm(at,1));
check_equal(tmp(subm(a,1,1,2,2)*subm(a,1,2,2,2)), subm(a,1,1,2,2)*subm(a,1,2,2,2));
temp = at*a;
temp2 = temp;
dlog << LTRACE << "1.9";
check_equal(tmp(trans(a*trans(rowm(a,1)))) , trans(a*trans(rowm(a,1))));
dlog << LTRACE << "1.10";
check_equal(tmp(trans(a*colm(at,1))) , trans(a*colm(at,1)));
dlog << LTRACE << "1.11";
check_equal(tmp(trans(subm(a,1,1,2,2)*subm(a,1,2,2,2))), trans(subm(a,1,1,2,2)*subm(a,1,2,2,2)));
dlog << LTRACE << "1.12";
temp += 3.5*at*a;
assign_no_blas(temp2, temp2 + 3.5*at*a);
check_equal(temp, temp2);
{
temp = at*a;
temp2 = temp;
temp -= at*3.5*a;
assign_no_blas(temp2, temp2 - at*3.5*a);
check_equal(temp, temp2);
temp += 3.5*at*a;
assign_no_blas(temp2, temp2 + 3.5*at*a);
check_equal(temp, temp2);
temp = temp + 4*at*a;
assign_no_blas(temp2, temp2 + 4*at*a);
check_equal(temp, temp2);
temp -= at*3.5*a;
assign_no_blas(temp2, temp2 - at*3.5*a);
check_equal(temp, temp2);
temp = temp - 2.4*at*a;
assign_no_blas(temp2, temp2 - 2.4*at*a);
check_equal(temp, temp2);
temp = temp + 4*at*a;
assign_no_blas(temp2, temp2 + 4*at*a);
check_equal(temp, temp2);
temp = temp - 2.4*at*a;
assign_no_blas(temp2, temp2 - 2.4*at*a);
check_equal(temp, temp2);
}
dlog << LTRACE << "1.13";
{
temp = trans(at*a);
temp2 = temp;
temp3 = temp;
dlog << LTRACE << "1.14";
temp += trans(3.5*at*a);
assign_no_blas(temp2, temp2 + trans(3.5*at*a));
check_equal(temp, temp2);
dlog << LTRACE << "1.15";
temp -= trans(at*3.5*a);
assign_no_blas(temp2, temp2 - trans(at*3.5*a));
check_equal(temp, temp2);
dlog << LTRACE << "1.16";
temp = trans(temp + 4*at*a);
assign_no_blas(temp3, trans(temp2 + 4*at*a));
check_equal(temp, temp3);
temp2 = temp;
dlog << LTRACE << "1.17";
temp = trans(temp - 2.4*at*a);
assign_no_blas(temp3, trans(temp2 - 2.4*at*a));
check_equal(temp, temp3);
}
dlog << LTRACE << "1.18";
// GEMV tests
check_equal(tmp(a*cv4), a*cv4);
check_equal(tmp(trans(a*cv4)), trans(a*cv4));
check_equal(tmp(rv3*a), rv3*a);
check_equal(tmp(trans(cv4)*at), trans(cv4)*at);
check_equal(tmp(a*trans(rv4)), a*trans(rv4));
check_equal(tmp(trans(a*trans(rv4))), trans(a*trans(rv4)));
check_equal(tmp(trans(a)*cv3), trans(a)*cv3);
check_equal(tmp(rv4*trans(a)), rv4*trans(a));
......@@ -291,21 +339,76 @@ namespace
// Test BLAS GER
temp.set_size(cols,cols);
set_all_elements(temp,3);
temp2 = temp;
{
temp.set_size(cols,cols);
set_all_elements(temp,3);
temp2 = temp;
dlog << LTRACE << "8";
temp += cv4*rv4;
assign_no_blas(temp2, temp2 + cv4*rv4);
check_equal(temp, temp2);
dlog << LTRACE << "8";
temp += cv4*rv4;
assign_no_blas(temp2, temp2 + cv4*rv4);
check_equal(temp, temp2);
dlog << LTRACE << "8.3";
temp = temp + cv4*rv4;
assign_no_blas(temp2, temp2 + cv4*rv4);
check_equal(temp, temp2);
dlog << LTRACE << "8.9";
dlog << LTRACE << "8.3";
temp = temp + cv4*rv4;
assign_no_blas(temp2, temp2 + cv4*rv4);
check_equal(temp, temp2);
dlog << LTRACE << "8.9";
}
{
temp.set_size(cols,cols);
set_all_elements(temp,3);
temp2 = temp;
temp3 = 0;
dlog << LTRACE << "8.10";
temp += trans(cv4*rv4);
assign_no_blas(temp3, temp2 + trans(cv4*rv4));
check_equal(temp, temp3);
temp3 = 0;
dlog << LTRACE << "8.11";
temp2 = temp;
temp = trans(temp + cv4*rv4);
assign_no_blas(temp3, trans(temp2 + cv4*rv4));
check_equal(temp, temp3);
dlog << LTRACE << "8.12";
}
{
matrix<complex<type> > temp, temp2, temp3;
matrix<complex<type>,0,1 > cv4;
matrix<complex<type>,1,0 > rv4;
cv4.set_size(cols);
rv4.set_size(cols);
temp.set_size(cols,cols);
set_all_elements(temp,complex<type>(3,5));
temp(cols-1, cols-4) = 9;
temp2 = temp;
temp3.set_size(cols,cols);
temp3 = 0;
for (long i = 0; i < rv4.size(); ++i)
{
rv4(i) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd));
cv4(i) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd));
}
dlog << LTRACE << "8.13";
temp += trans(cv4*rv4);
assign_no_blas(temp3, temp2 + trans(cv4*rv4));
c_check_equal(temp, temp3);
temp3 = 0;
dlog << LTRACE << "8.14";
temp2 = temp;
temp = trans(temp + cv4*rv4);
assign_no_blas(temp3, trans(temp2 + cv4*rv4));
c_check_equal(temp, temp3);
dlog << LTRACE << "8.15";
}
......@@ -340,6 +443,7 @@ namespace
// Test DOT
check_equal( tmp(rv4*cv4), rv4*cv4);
check_equal( tmp(trans(rv4*cv4)), trans(rv4*cv4));
check_equal( tmp(trans(cv4)*trans(rv4)), trans(cv4)*trans(rv4));
check_equal( tmp(rv4*3.9*cv4), rv4*3.9*cv4);
check_equal( tmp(trans(cv4)*3.9*trans(rv4)), trans(cv4)*3.9*trans(rv4));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment