Commit dfa5d579 authored by Davis King's avatar Davis King

Added a set of regression tests that make sure the BLAS bindings are actually getting

called when they are supposed to be.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403284
parent 812fd55c
#
# This is a CMake makefile. You can find the cmake utility and
# information about it at http://www.cmake.org
#
# setting this makes CMake allow normal looking IF ELSE statements
SET(CMAKE_ALLOW_LOOSE_LOOP_CONSTRUCTS true)
cmake_minimum_required(VERSION 2.4)
if(COMMAND cmake_policy)
cmake_policy(SET CMP0003 NEW)
endif(COMMAND cmake_policy)
# This variable contains a list of all the tests we are building
# into the regression test suite.
set (tests
blas_bindings_gemm.cpp
blas_bindings_gemv.cpp
blas_bindings_ger.cpp
blas_bindings_dot.cpp
)
# create a variable called target_name and set it to the string "test"
set (target_name test)
PROJECT(${target_name})
# add all the cpp files we want to compile to this list. This tells
# cmake that they are part of our target (which is the executable named test)
ADD_EXECUTABLE(${target_name} ../main.cpp ../tester.cpp ${tests})
# add the folder containing the dlib folder to the include path
INCLUDE_DIRECTORIES(../../..)
ADD_DEFINITIONS(-DDLIB_TEST_BLAS_BINDINGS)
# There is a CMakeLists.txt file in the dlib source folder that tells cmake
# how to build the dlib library. Tell cmake about that file.
add_subdirectory(../.. dlib_build)
# Tell cmake to link our target executable to dlib
TARGET_LINK_LIBRARIES(${target_name} dlib )
// Copyright (C) 2009 Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include "../tester.h"
#include <dlib/matrix.h>
#ifndef DLIB_USE_BLAS
#error "BLAS bindings must be used for this test to make any sense"
#endif
namespace dlib
{
namespace blas_bindings
{
// This is a little screwy. This function is used inside the BLAS
// bindings to count how many times each of the BLAS functions get called.
#ifdef DLIB_TEST_BLAS_BINDINGS
int& counter_dot() { static int counter = 0; return counter; }
#endif
}
}
namespace
{
using namespace test;
using namespace std;
// Declare the logger we will use in this test. The name of the logger
// should start with "test."
dlib::logger dlog("test.dot");
class blas_bindings_dot_tester : public tester
{
public:
blas_bindings_dot_tester (
) :
tester (
"test_dot", // the command line argument name for this test
"Run example tests.", // the command line argument description
0 // the number of command line arguments for this test
)
{}
template <typename matrix_type, typename cv_type, typename rv_type>
void test_dot_stuff(
matrix_type& m,
rv_type& rv,
cv_type& cv
) const
{
using namespace dlib;
using namespace dlib::blas_bindings;
rv_type rv2;
cv_type cv2;
matrix_type m2;
typedef typename matrix_type::type scalar_type;
scalar_type val;
counter_dot() = 0;
m2 = rv*cv;
DLIB_TEST(counter_dot() == 1);
counter_dot() = 0;
val = rv*cv;
DLIB_TEST(counter_dot() == 1);
counter_dot() = 0;
val = rv*3*cv;
DLIB_TEST(counter_dot() == 1);
counter_dot() = 0;
val = rv*trans(rv)*3;
DLIB_TEST(counter_dot() == 1);
counter_dot() = 0;
val = trans(rv*trans(rv)*3 + trans(cv)*cv);
DLIB_TEST(counter_dot() == 2);
counter_dot() = 0;
val = trans(cv)*cv;
DLIB_TEST(counter_dot() == 1);
counter_dot() = 0;
val = trans(cv)*trans(rv);
DLIB_TEST(counter_dot() == 1);
counter_dot() = 0;
val = dot(rv,cv);
DLIB_TEST(counter_dot() == 1);
counter_dot() = 0;
val = dot(cv,cv);
DLIB_TEST(counter_dot() == 1);
counter_dot() = 0;
val = dot(rv,rv);
DLIB_TEST(counter_dot() == 1);
counter_dot() = 0;
val = dot(rv,trans(rv));
DLIB_TEST(counter_dot() == 1);
counter_dot() = 0;
val = dot(trans(cv),cv);
DLIB_TEST(counter_dot() == 1);
counter_dot() = 0;
val = dot(trans(cv),trans(rv));
DLIB_TEST(counter_dot() == 1);
// This does one dot and one gemv
counter_dot() = 0;
val = trans(cv)*m*trans(rv);
DLIB_TEST_MSG(counter_dot() == 1, counter_dot());
// This does one dot and two gemv
counter_dot() = 0;
val = (trans(cv)*m)*(m*trans(rv));
DLIB_TEST_MSG(counter_dot() == 1, counter_dot());
// This does one dot and two gemv
counter_dot() = 0;
val = trans(cv)*m*trans(m)*trans(rv);
DLIB_TEST_MSG(counter_dot() == 1, counter_dot());
}
template <typename matrix_type, typename cv_type, typename rv_type>
void test_dot_stuff_conj(
matrix_type& m,
rv_type& rv,
cv_type& cv
) const
{
using namespace dlib;
using namespace dlib::blas_bindings;
rv_type rv2;
cv_type cv2;
matrix_type m2;
typedef typename matrix_type::type scalar_type;
scalar_type val;
counter_dot() = 0;
val = conj(rv)*cv;
DLIB_TEST(counter_dot() == 1);
counter_dot() = 0;
val = trans(conj(cv))*cv;
DLIB_TEST(counter_dot() == 1);
counter_dot() = 0;
val = trans(conj(cv))*trans(rv);
DLIB_TEST(counter_dot() == 1);
counter_dot() = 0;
val = trans(conj(cv))*3*trans(rv);
DLIB_TEST(counter_dot() == 1);
}
void perform_test (
)
{
using namespace dlib;
typedef dlib::memory_manager<char>::kernel_1a mm;
dlog << dlib::LINFO << "test double";
{
matrix<double> m = randm(4,4);
matrix<double,1,0> rv = randm(1,4);
matrix<double,0,1> cv = randm(4,1);
test_dot_stuff(m,rv,cv);
}
dlog << dlib::LINFO << "test float";
{
matrix<float> m = matrix_cast<float>(randm(4,4));
matrix<float,1,0> rv = matrix_cast<float>(randm(1,4));
matrix<float,0,1> cv = matrix_cast<float>(randm(4,1));
test_dot_stuff(m,rv,cv);
}
dlog << dlib::LINFO << "test complex<double>";
{
matrix<complex<double> > m = complex_matrix(randm(4,4), randm(4,4));
matrix<complex<double>,1,0> rv = complex_matrix(randm(1,4), randm(1,4));
matrix<complex<double>,0,1> cv = complex_matrix(randm(4,1), randm(4,1));
test_dot_stuff(m,rv,cv);
test_dot_stuff_conj(m,rv,cv);
}
dlog << dlib::LINFO << "test complex<float>";
{
matrix<complex<float> > m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
matrix<complex<float>,1,0> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
matrix<complex<float>,0,1> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
test_dot_stuff(m,rv,cv);
test_dot_stuff_conj(m,rv,cv);
}
dlog << dlib::LINFO << "test double, column major";
{
matrix<double,0,0,mm,column_major_layout> m = randm(4,4);
matrix<double,1,0,mm,column_major_layout> rv = randm(1,4);
matrix<double,0,1,mm,column_major_layout> cv = randm(4,1);
test_dot_stuff(m,rv,cv);
}
dlog << dlib::LINFO << "test float, column major";
{
matrix<float,0,0,mm,column_major_layout> m = matrix_cast<float>(randm(4,4));
matrix<float,1,0,mm,column_major_layout> rv = matrix_cast<float>(randm(1,4));
matrix<float,0,1,mm,column_major_layout> cv = matrix_cast<float>(randm(4,1));
test_dot_stuff(m,rv,cv);
}
dlog << dlib::LINFO << "test complex<double>, column major";
{
matrix<complex<double>,0,0,mm,column_major_layout > m = complex_matrix(randm(4,4), randm(4,4));
matrix<complex<double>,1,0,mm,column_major_layout> rv = complex_matrix(randm(1,4), randm(1,4));
matrix<complex<double>,0,1,mm,column_major_layout> cv = complex_matrix(randm(4,1), randm(4,1));
test_dot_stuff(m,rv,cv);
test_dot_stuff_conj(m,rv,cv);
}
dlog << dlib::LINFO << "test complex<float>, column major";
{
matrix<complex<float>,0,0,mm,column_major_layout > m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
matrix<complex<float>,1,0,mm,column_major_layout> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
matrix<complex<float>,0,1,mm,column_major_layout> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
test_dot_stuff(m,rv,cv);
test_dot_stuff_conj(m,rv,cv);
}
print_spinner();
}
};
blas_bindings_dot_tester a;
}
// Copyright (C) 2009 Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include "../tester.h"
#include <dlib/matrix.h>
#ifndef DLIB_USE_BLAS
#error "BLAS bindings must be used for this test to make any sense"
#endif
namespace dlib
{
namespace blas_bindings
{
// This is a little screwy. This function is used inside the BLAS
// bindings to count how many times each of the BLAS functions get called.
#ifdef DLIB_TEST_BLAS_BINDINGS
int& counter_gemm() { static int counter = 0; return counter; }
#endif
}
}
namespace
{
using namespace test;
using namespace std;
// Declare the logger we will use in this test. The name of the logger
// should start with "test."
dlib::logger dlog("test.gemm");
class blas_bindings_gemm_tester : public tester
{
public:
blas_bindings_gemm_tester (
) :
tester (
"test_gemm", // the command line argument name for this test
"Run example tests.", // the command line argument description
0 // the number of command line arguments for this test
)
{}
template <typename matrix_type>
void test_gemm_stuff(
matrix_type& a
) const
{
using namespace dlib;
using namespace dlib::blas_bindings;
matrix_type b;
counter_gemm() = 0;
b = a*a;
DLIB_TEST(counter_gemm() == 1);
counter_gemm() = 0;
b = a*trans(a) + a;
DLIB_TEST(counter_gemm() == 1);
counter_gemm() = 0;
b = trans(a)*trans(a) + a;
DLIB_TEST(counter_gemm() == 1);
counter_gemm() = 0;
b = trans(trans(trans(a)*a + a));
DLIB_TEST(counter_gemm() == 1);
counter_gemm() = 0;
b = a*a*a*a;
DLIB_TEST(counter_gemm() == 3);
counter_gemm() = 0;
a = a*a*a*a;
DLIB_TEST(counter_gemm() == 3);
counter_gemm() = 0;
a = (b + a*trans(a)*a*3*a)*trans(b);
DLIB_TEST(counter_gemm() == 4);
counter_gemm() = 0;
a = trans((trans(b) + trans(a)*trans(a)*a*3*a)*trans(b));
DLIB_TEST(counter_gemm() == 4);
counter_gemm() = 0;
a = trans((trans(b) + trans(a)*(a)*trans(a)*3*a)*trans(b));
DLIB_TEST(counter_gemm() == 4);
counter_gemm() = 0;
a = trans((trans(b) + trans(a)*(a + b)*trans(a)*3*a)*trans(b));
DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm());
counter_gemm() = 0;
a = trans((trans(b) + trans(a)*(a*8 + b+b+b+b)*trans(a)*3*a)*trans(b));
DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm());
}
template <typename matrix_type>
void test_gemm_stuff_conj(
matrix_type& a
) const
{
using namespace dlib;
using namespace dlib::blas_bindings;
matrix_type b;
counter_gemm() = 0;
b = a*conj(a);
DLIB_TEST(counter_gemm() == 1);
counter_gemm() = 0;
b = a*trans(conj(a)) + a;
DLIB_TEST(counter_gemm() == 1);
counter_gemm() = 0;
b = conj(trans(a))*trans(a) + a;
DLIB_TEST(counter_gemm() == 1);
counter_gemm() = 0;
b = trans(trans(trans(a)*conj(a) + conj(a)));
DLIB_TEST(counter_gemm() == 1);
counter_gemm() = 0;
b = a*a*conj(a)*a;
DLIB_TEST(counter_gemm() == 3);
counter_gemm() = 0;
a = a*trans(conj(a))*a*a;
DLIB_TEST(counter_gemm() == 3);
counter_gemm() = 0;
a = (b + a*trans(conj(a))*a*3*a)*trans(b);
DLIB_TEST(counter_gemm() == 4);
counter_gemm() = 0;
a = (trans((conj(trans(b)) + trans(a)*conj(trans(a))*a*3*a)*trans(b)));
DLIB_TEST(counter_gemm() == 4);
counter_gemm() = 0;
a = ((trans(b) + trans(a)*(a)*trans(a)*3*a)*trans(conj(b)));
DLIB_TEST(counter_gemm() == 4);
counter_gemm() = 0;
a = trans((trans(b) + trans(a)*conj(a + b)*trans(a)*3*a)*trans(b));
DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm());
counter_gemm() = 0;
a = trans((trans(b) + trans(a)*(a*8 + b+b+b+b)*trans(a)*3*conj(a))*trans(b));
DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm());
}
void perform_test (
)
{
using namespace dlib;
typedef dlib::memory_manager<char>::kernel_1a mm;
dlog << dlib::LINFO << "test double";
{
matrix<double> a = randm(4,4);
test_gemm_stuff(a);
}
dlog << dlib::LINFO << "test float";
{
matrix<float> a = matrix_cast<float>(randm(4,4));
test_gemm_stuff(a);
}
dlog << dlib::LINFO << "test complex<float>";
{
matrix<float> a = matrix_cast<float>(randm(4,4));
matrix<float> b = matrix_cast<float>(randm(4,4));
matrix<complex<float> > c = complex_matrix(a,b);
test_gemm_stuff(c);
test_gemm_stuff_conj(c);
}
dlog << dlib::LINFO << "test complex<double>";
{
matrix<double> a = matrix_cast<double>(randm(4,4));
matrix<double> b = matrix_cast<double>(randm(4,4));
matrix<complex<double> > c = complex_matrix(a,b);
test_gemm_stuff(c);
test_gemm_stuff_conj(c);
}
dlog << dlib::LINFO << "test double, column major";
{
matrix<double,100,100,mm,column_major_layout> a = randm(100,100);
test_gemm_stuff(a);
}
dlog << dlib::LINFO << "test float, column major";
{
matrix<float,100,100,mm,column_major_layout> a = matrix_cast<float>(randm(100,100));
test_gemm_stuff(a);
}
dlog << dlib::LINFO << "test complex<double>, column major";
{
matrix<double,100,100,mm,column_major_layout> a = matrix_cast<double>(randm(100,100));
matrix<double,100,100,mm,column_major_layout> b = matrix_cast<double>(randm(100,100));
matrix<complex<double>,100,100,mm,column_major_layout > c = complex_matrix(a,b);
test_gemm_stuff(c);
test_gemm_stuff_conj(c);
}
dlog << dlib::LINFO << "test complex<float>, column major";
{
matrix<float,100,100,mm,column_major_layout> a = matrix_cast<float>(randm(100,100));
matrix<float,100,100,mm,column_major_layout> b = matrix_cast<float>(randm(100,100));
matrix<complex<float>,100,100,mm,column_major_layout > c = complex_matrix(a,b);
test_gemm_stuff(c);
test_gemm_stuff_conj(c);
}
print_spinner();
}
};
blas_bindings_gemm_tester a;
}
// Copyright (C) 2009 Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include "../tester.h"
#include <dlib/matrix.h>
#ifndef DLIB_USE_BLAS
#error "BLAS bindings must be used for this test to make any sense"
#endif
namespace dlib
{
namespace blas_bindings
{
// This is a little screwy. This function is used inside the BLAS
// bindings to count how many times each of the BLAS functions get called.
#ifdef DLIB_TEST_BLAS_BINDINGS
int& counter_gemv() { static int counter = 0; return counter; }
#endif
}
}
namespace
{
using namespace test;
using namespace std;
// Declare the logger we will use in this test. The name of the logger
// should start with "test."
dlib::logger dlog("test.gemv");
class blas_bindings_gemv_tester : public tester
{
public:
blas_bindings_gemv_tester (
) :
tester (
"test_gemv", // the command line argument name for this test
"Run example tests.", // the command line argument description
0 // the number of command line arguments for this test
)
{}
template <typename matrix_type, typename rv_type, typename cv_type>
void test_gemv_stuff(
matrix_type& m,
cv_type& cv,
rv_type& rv
) const
{
using namespace dlib;
using namespace dlib::blas_bindings;
cv_type cv2;
rv_type rv2;
typedef typename matrix_type::type scalar_type;
scalar_type val;
counter_gemv() = 0;
cv2 = m*cv;
DLIB_TEST(counter_gemv() == 1);
counter_gemv() = 0;
cv2 = m*2*cv;
DLIB_TEST(counter_gemv() == 1);
counter_gemv() = 0;
cv2 = m*2*trans(rv);
DLIB_TEST(counter_gemv() == 1);
counter_gemv() = 0;
rv2 = trans(m*2*cv);
DLIB_TEST(counter_gemv() == 1);
counter_gemv() = 0;
rv2 = rv*m;
DLIB_TEST(counter_gemv() == 1);
counter_gemv() = 0;
rv2 = (rv + rv)*m;
DLIB_TEST(counter_gemv() == 1);
counter_gemv() = 0;
rv2 = trans(cv)*m;
DLIB_TEST(counter_gemv() == 1);
dlog << dlib::LTRACE << 1;
counter_gemv() = 0;
rv2 = trans(cv)*trans(m) + rv*trans(m);
DLIB_TEST(counter_gemv() == 2);
dlog << dlib::LTRACE << 2;
counter_gemv() = 0;
cv2 = m*trans(trans(cv)*trans(m) + 3*rv*trans(m));
DLIB_TEST(counter_gemv() == 3);
// This does one dot and one gemv
counter_gemv() = 0;
val = trans(cv)*m*trans(rv);
DLIB_TEST_MSG(counter_gemv() == 1, counter_gemv());
// This does one dot and two gemv
counter_gemv() = 0;
val = (trans(cv)*m)*(m*trans(rv));
DLIB_TEST_MSG(counter_gemv() == 2, counter_gemv());
// This does one dot and two gemv
counter_gemv() = 0;
val = trans(cv)*m*trans(m)*trans(rv);
DLIB_TEST_MSG(counter_gemv() == 2, counter_gemv());
}
template <typename matrix_type, typename rv_type, typename cv_type>
void test_gemv_stuff_conj(
matrix_type& m,
cv_type& cv,
rv_type& rv
) const
{
using namespace dlib;
using namespace dlib::blas_bindings;
cv_type cv2;
rv_type rv2;
counter_gemv() = 0;
cv2 = trans(cv)*conj(m);
DLIB_TEST(counter_gemv() == 1);
counter_gemv() = 0;
cv2 = conj(trans(m))*rv;
DLIB_TEST(counter_gemv() == 1);
counter_gemv() = 0;
cv2 = conj(trans(m))*trans(cv);
DLIB_TEST(counter_gemv() == 1);
counter_gemv() = 0;
cv2 = trans(trans(cv)*conj(2*m) + conj(3*trans(m))*rv + conj(trans(m)*3)*trans(cv));
DLIB_TEST(counter_gemv() == 3);
}
void perform_test (
)
{
using namespace dlib;
typedef dlib::memory_manager<char>::kernel_1a mm;
dlog << dlib::LINFO << "test double";
{
matrix<double> m = randm(4,4);
matrix<double,0,1> cv = randm(4,1);
matrix<double,1,0> rv = randm(1,4);
test_gemv_stuff(m,cv,rv);
}
dlog << dlib::LINFO << "test float";
{
matrix<float> m = matrix_cast<float>(randm(4,4));
matrix<float,0,1> cv = matrix_cast<float>(randm(4,1));
matrix<float,1,0> rv = matrix_cast<float>(randm(1,4));
test_gemv_stuff(m,cv,rv);
}
dlog << dlib::LINFO << "test complex<double>";
{
matrix<complex<double> > m = complex_matrix(randm(4,4), randm(4,4));
matrix<complex<double>,0,1> cv = complex_matrix(randm(4,1), randm(4,1));
matrix<complex<double>,1,0> rv = complex_matrix(randm(1,4), randm(1,4));
test_gemv_stuff(m,cv,rv);
}
dlog << dlib::LINFO << "test complex<float>";
{
matrix<complex<float> > m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
matrix<complex<float>,0,1> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
matrix<complex<float>,1,0> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
test_gemv_stuff(m,cv,rv);
}
dlog << dlib::LINFO << "test double";
{
matrix<double,0,0,mm,column_major_layout> m = randm(4,4);
matrix<double,0,1,mm,column_major_layout> cv = randm(4,1);
matrix<double,1,0,mm,column_major_layout> rv = randm(1,4);
test_gemv_stuff(m,cv,rv);
}
dlog << dlib::LINFO << "test float";
{
matrix<float,0,0,mm,column_major_layout> m = matrix_cast<float>(randm(4,4));
matrix<float,0,1,mm,column_major_layout> cv = matrix_cast<float>(randm(4,1));
matrix<float,1,0,mm,column_major_layout> rv = matrix_cast<float>(randm(1,4));
test_gemv_stuff(m,cv,rv);
}
dlog << dlib::LINFO << "test complex<double>";
{
matrix<complex<double>,0,0,mm,column_major_layout > m = complex_matrix(randm(4,4), randm(4,4));
matrix<complex<double>,0,1,mm,column_major_layout> cv = complex_matrix(randm(4,1), randm(4,1));
matrix<complex<double>,1,0,mm,column_major_layout> rv = complex_matrix(randm(1,4), randm(1,4));
test_gemv_stuff(m,cv,rv);
}
dlog << dlib::LINFO << "test complex<float>";
{
matrix<complex<float>,0,0,mm,column_major_layout > m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
matrix<complex<float>,0,1,mm,column_major_layout> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
matrix<complex<float>,1,0,mm,column_major_layout> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
test_gemv_stuff(m,cv,rv);
}
print_spinner();
}
};
blas_bindings_gemv_tester a;
}
// Copyright (C) 2009 Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include "../tester.h"
#include <dlib/matrix.h>
#ifndef DLIB_USE_BLAS
#error "BLAS bindings must be used for this test to make any sense"
#endif
namespace dlib
{
namespace blas_bindings
{
// This is a little screwy. This function is used inside the BLAS
// bindings to count how many times each of the BLAS functions get called.
#ifdef DLIB_TEST_BLAS_BINDINGS
int& counter_ger() { static int counter = 0; return counter; }
#endif
}
}
namespace
{
using namespace test;
using namespace std;
// Declare the logger we will use in this test. The name of the logger
// should start with "test."
dlib::logger dlog("test.ger");
class blas_bindings_ger_tester : public tester
{
public:
blas_bindings_ger_tester (
) :
tester (
"test_ger", // the command line argument name for this test
"Run example tests.", // the command line argument description
0 // the number of command line arguments for this test
)
{}
template <typename matrix_type, typename cv_type, typename rv_type>
void test_ger_stuff(
matrix_type& m,
rv_type& rv,
cv_type& cv
) const
{
using namespace dlib;
using namespace dlib::blas_bindings;
rv_type rv2;
cv_type cv2;
matrix_type m2;
counter_ger() = 0;
m2 = m + cv*rv;
DLIB_TEST_MSG(counter_ger() == 1, counter_ger());
counter_ger() = 0;
m += trans(rv)*rv;
DLIB_TEST(counter_ger() == 1);
counter_ger() = 0;
m += trans(rv)*trans(cv);
DLIB_TEST(counter_ger() == 1);
counter_ger() = 0;
m += cv*trans(cv);
DLIB_TEST(counter_ger() == 1);
counter_ger() = 0;
m += trans(rv)*rv + trans(cv*3*rv);
DLIB_TEST(counter_ger() == 2);
}
template <typename matrix_type, typename cv_type, typename rv_type>
void test_ger_stuff_conj(
matrix_type& m,
rv_type& rv,
cv_type& cv
) const
{
using namespace dlib;
using namespace dlib::blas_bindings;
rv_type rv2;
cv_type cv2;
matrix_type m2;
counter_ger() = 0;
m += cv*conj(rv);
DLIB_TEST_MSG(counter_ger() == 1, counter_ger());
counter_ger() = 0;
m += trans(rv)*conj(rv);
DLIB_TEST(counter_ger() == 1);
counter_ger() = 0;
m += trans(rv)*conj(trans(cv));
DLIB_TEST(counter_ger() == 1);
counter_ger() = 0;
m += cv*trans(conj(cv));
DLIB_TEST(counter_ger() == 1);
counter_ger() = 0;
m += trans(rv)*rv + trans(cv*3*conj(rv));
DLIB_TEST(counter_ger() == 2);
}
void perform_test (
)
{
using namespace dlib;
typedef dlib::memory_manager<char>::kernel_1a mm;
dlog << dlib::LINFO << "test double";
{
matrix<double> m = randm(4,4);
matrix<double,1,0> rv = randm(1,4);
matrix<double,0,1> cv = randm(4,1);
test_ger_stuff(m,rv,cv);
}
dlog << dlib::LINFO << "test float";
{
matrix<float> m = matrix_cast<float>(randm(4,4));
matrix<float,1,0> rv = matrix_cast<float>(randm(1,4));
matrix<float,0,1> cv = matrix_cast<float>(randm(4,1));
test_ger_stuff(m,rv,cv);
}
dlog << dlib::LINFO << "test complex<double>";
{
matrix<complex<double> > m = complex_matrix(randm(4,4), randm(4,4));
matrix<complex<double>,1,0> rv = complex_matrix(randm(1,4), randm(1,4));
matrix<complex<double>,0,1> cv = complex_matrix(randm(4,1), randm(4,1));
test_ger_stuff(m,rv,cv);
test_ger_stuff_conj(m,rv,cv);
}
dlog << dlib::LINFO << "test complex<float>";
{
matrix<complex<float> > m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
matrix<complex<float>,1,0> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
matrix<complex<float>,0,1> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
test_ger_stuff(m,rv,cv);
test_ger_stuff_conj(m,rv,cv);
}
dlog << dlib::LINFO << "test double";
{
matrix<double,0,0,mm,column_major_layout> m = randm(4,4);
matrix<double,1,0,mm,column_major_layout> rv = randm(1,4);
matrix<double,0,1,mm,column_major_layout> cv = randm(4,1);
test_ger_stuff(m,rv,cv);
}
dlog << dlib::LINFO << "test float";
{
matrix<float,0,0,mm,column_major_layout> m = matrix_cast<float>(randm(4,4));
matrix<float,1,0,mm,column_major_layout> rv = matrix_cast<float>(randm(1,4));
matrix<float,0,1,mm,column_major_layout> cv = matrix_cast<float>(randm(4,1));
test_ger_stuff(m,rv,cv);
}
dlog << dlib::LINFO << "test complex<double>";
{
matrix<complex<double>,0,0,mm,column_major_layout > m = complex_matrix(randm(4,4), randm(4,4));
matrix<complex<double>,1,0,mm,column_major_layout> rv = complex_matrix(randm(1,4), randm(1,4));
matrix<complex<double>,0,1,mm,column_major_layout> cv = complex_matrix(randm(4,1), randm(4,1));
test_ger_stuff(m,rv,cv);
test_ger_stuff_conj(m,rv,cv);
}
dlog << dlib::LINFO << "test complex<float>";
{
matrix<complex<float>,0,0,mm,column_major_layout > m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
matrix<complex<float>,1,0,mm,column_major_layout> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
matrix<complex<float>,0,1,mm,column_major_layout> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
test_ger_stuff(m,rv,cv);
test_ger_stuff_conj(m,rv,cv);
}
print_spinner();
}
};
blas_bindings_ger_tester a;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment