Commit 8dbe2979 authored by Davis King's avatar Davis King

Changed the test code so that it executes faster and also prints the spinner more often.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403302
parent e127f5f8
...@@ -44,13 +44,14 @@ namespace ...@@ -44,13 +44,14 @@ namespace
template <typename matrix_type> template <typename matrix_type>
void test_gemm_stuff( void test_gemm_stuff(
matrix_type& a const matrix_type& c
) const ) const
{ {
using namespace dlib; using namespace dlib;
using namespace dlib::blas_bindings; using namespace dlib::blas_bindings;
matrix_type b; matrix_type b, a;
a = c;
counter_gemm() = 0; counter_gemm() = 0;
b = a*a; b = a*a;
...@@ -71,41 +72,49 @@ namespace ...@@ -71,41 +72,49 @@ namespace
counter_gemm() = 0; counter_gemm() = 0;
b = a*a*a*a; b = a*a*a*a;
DLIB_TEST(counter_gemm() == 3); DLIB_TEST(counter_gemm() == 3);
b = c;
counter_gemm() = 0; counter_gemm() = 0;
a = a*a*a*a; a = a*a*a*a;
DLIB_TEST(counter_gemm() == 3); DLIB_TEST(counter_gemm() == 3);
a = c;
counter_gemm() = 0; counter_gemm() = 0;
a = (b + a*trans(a)*a*3*a)*trans(b); a = (b + a*trans(a)*a*3*a)*trans(b);
DLIB_TEST(counter_gemm() == 4); DLIB_TEST(counter_gemm() == 4);
a = c;
counter_gemm() = 0; counter_gemm() = 0;
a = trans((trans(b) + trans(a)*trans(a)*a*3*a)*trans(b)); a = trans((trans(b) + trans(a)*trans(a)*a*3*a)*trans(b));
DLIB_TEST(counter_gemm() == 4); DLIB_TEST(counter_gemm() == 4);
a = c;
counter_gemm() = 0; counter_gemm() = 0;
a = trans((trans(b) + trans(a)*(a)*trans(a)*3*a)*trans(b)); a = trans((trans(b) + trans(a)*(a)*trans(a)*3*a)*trans(b));
DLIB_TEST(counter_gemm() == 4); DLIB_TEST(counter_gemm() == 4);
a = c;
counter_gemm() = 0; counter_gemm() = 0;
a = trans((trans(b) + trans(a)*(a + b)*trans(a)*3*a)*trans(b)); a = trans((trans(b) + trans(a)*(a + b)*trans(a)*3*a)*trans(b));
DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm()); DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm());
a = c;
counter_gemm() = 0; counter_gemm() = 0;
a = trans((trans(b) + trans(a)*(a*8 + b+b+b+b)*trans(a)*3*a)*trans(b)); a = trans((trans(b) + trans(a)*(a*8 + b+b+b+b)*trans(a)*3*a)*trans(b));
DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm()); DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm());
a = c;
} }
template <typename matrix_type> template <typename matrix_type>
void test_gemm_stuff_conj( void test_gemm_stuff_conj(
matrix_type& a const matrix_type& c
) const ) const
{ {
using namespace dlib; using namespace dlib;
using namespace dlib::blas_bindings; using namespace dlib::blas_bindings;
matrix_type b; matrix_type b, a;
a = c;
counter_gemm() = 0; counter_gemm() = 0;
b = a*conj(a); b = a*conj(a);
...@@ -126,30 +135,37 @@ namespace ...@@ -126,30 +135,37 @@ namespace
counter_gemm() = 0; counter_gemm() = 0;
b = a*a*conj(a)*a; b = a*a*conj(a)*a;
DLIB_TEST(counter_gemm() == 3); DLIB_TEST(counter_gemm() == 3);
b = c;
counter_gemm() = 0; counter_gemm() = 0;
a = a*trans(conj(a))*a*a; a = a*trans(conj(a))*a*a;
DLIB_TEST(counter_gemm() == 3); DLIB_TEST(counter_gemm() == 3);
a = c;
counter_gemm() = 0; counter_gemm() = 0;
a = (b + a*trans(conj(a))*a*3*a)*trans(b); a = (b + a*trans(conj(a))*a*3*a)*trans(b);
DLIB_TEST(counter_gemm() == 4); DLIB_TEST(counter_gemm() == 4);
a = c;
counter_gemm() = 0; counter_gemm() = 0;
a = (trans((conj(trans(b)) + trans(a)*conj(trans(a))*a*3*a)*trans(b))); a = (trans((conj(trans(b)) + trans(a)*conj(trans(a))*a*3*a)*trans(b)));
DLIB_TEST(counter_gemm() == 4); DLIB_TEST(counter_gemm() == 4);
a = c;
counter_gemm() = 0; counter_gemm() = 0;
a = ((trans(b) + trans(a)*(a)*trans(a)*3*a)*trans(conj(b))); a = ((trans(b) + trans(a)*(a)*trans(a)*3*a)*trans(conj(b)));
DLIB_TEST(counter_gemm() == 4); DLIB_TEST(counter_gemm() == 4);
a = c;
counter_gemm() = 0; counter_gemm() = 0;
a = trans((trans(b) + trans(a)*conj(a + b)*trans(a)*3*a)*trans(b)); a = trans((trans(b) + trans(a)*conj(a + b)*trans(a)*3*a)*trans(b));
DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm()); DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm());
a = c;
counter_gemm() = 0; counter_gemm() = 0;
a = trans((trans(b) + trans(a)*(a*8 + b+b+b+b)*trans(a)*3*conj(a))*trans(b)); a = trans((trans(b) + trans(a)*(a*8 + b+b+b+b)*trans(a)*3*conj(a))*trans(b));
DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm()); DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm());
a = c;
} }
void perform_test ( void perform_test (
...@@ -158,18 +174,22 @@ namespace ...@@ -158,18 +174,22 @@ namespace
using namespace dlib; using namespace dlib;
typedef dlib::memory_manager<char>::kernel_1a mm; typedef dlib::memory_manager<char>::kernel_1a mm;
print_spinner();
dlog << dlib::LINFO << "test double"; dlog << dlib::LINFO << "test double";
{ {
matrix<double> a = randm(4,4); matrix<double> a = randm(4,4);
test_gemm_stuff(a); test_gemm_stuff(a);
} }
print_spinner();
dlog << dlib::LINFO << "test float"; dlog << dlib::LINFO << "test float";
{ {
matrix<float> a = matrix_cast<float>(randm(4,4)); matrix<float> a = matrix_cast<float>(randm(4,4));
test_gemm_stuff(a); test_gemm_stuff(a);
} }
print_spinner();
dlog << dlib::LINFO << "test complex<float>"; dlog << dlib::LINFO << "test complex<float>";
{ {
matrix<float> a = matrix_cast<float>(randm(4,4)); matrix<float> a = matrix_cast<float>(randm(4,4));
...@@ -179,6 +199,7 @@ namespace ...@@ -179,6 +199,7 @@ namespace
test_gemm_stuff_conj(c); test_gemm_stuff_conj(c);
} }
print_spinner();
dlog << dlib::LINFO << "test complex<double>"; dlog << dlib::LINFO << "test complex<double>";
{ {
matrix<double> a = matrix_cast<double>(randm(4,4)); matrix<double> a = matrix_cast<double>(randm(4,4));
...@@ -189,6 +210,7 @@ namespace ...@@ -189,6 +210,7 @@ namespace
} }
print_spinner();
dlog << dlib::LINFO << "test double, column major"; dlog << dlib::LINFO << "test double, column major";
{ {
...@@ -196,12 +218,14 @@ namespace ...@@ -196,12 +218,14 @@ namespace
test_gemm_stuff(a); test_gemm_stuff(a);
} }
print_spinner();
dlog << dlib::LINFO << "test float, column major"; dlog << dlib::LINFO << "test float, column major";
{ {
matrix<float,100,100,mm,column_major_layout> a = matrix_cast<float>(randm(100,100)); matrix<float,100,100,mm,column_major_layout> a = matrix_cast<float>(randm(100,100));
test_gemm_stuff(a); test_gemm_stuff(a);
} }
print_spinner();
dlog << dlib::LINFO << "test complex<double>, column major"; dlog << dlib::LINFO << "test complex<double>, column major";
{ {
matrix<double,100,100,mm,column_major_layout> a = matrix_cast<double>(randm(100,100)); matrix<double,100,100,mm,column_major_layout> a = matrix_cast<double>(randm(100,100));
...@@ -211,6 +235,7 @@ namespace ...@@ -211,6 +235,7 @@ namespace
test_gemm_stuff_conj(c); test_gemm_stuff_conj(c);
} }
print_spinner();
dlog << dlib::LINFO << "test complex<float>, column major"; dlog << dlib::LINFO << "test complex<float>, column major";
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment