Commit 51666563 authored by Davis King's avatar Davis King

Added BLAS bindings for xscal and xaxpy routines.

parent c19bb9f3
......@@ -1502,9 +1502,7 @@ namespace dlib
const T a
)
{
const long size = data.nr()*data.nc();
for (long i = 0; i < size; ++i)
data(i) *= a;
*this = *this * a;
return *this;
}
......@@ -1512,9 +1510,7 @@ namespace dlib
const T a
)
{
const long size = data.nr()*data.nc();
for (long i = 0; i < size; ++i)
data(i) /= a;
*this = *this / a;
return *this;
}
......
......@@ -65,7 +65,7 @@ namespace dlib
template <typename T, bool Tb>
struct has_matrix_multiply<matrix_mul_scal_exp<T,Tb> >
{ const static bool value = has_matrix_multiply<T>::value; };
{ const static bool value = true; };
template <typename T>
struct has_matrix_multiply<matrix_div_scal_exp<T> >
......@@ -646,6 +646,22 @@ namespace dlib
// double, complex<float>, or complex<double> and the src_exp contains at least one
// matrix multiply.
template <
typename T, long NR, long NC, typename MM, typename L,
long NR2, long NC2, bool Sb
>
void matrix_assign_blas (
matrix<T,NR,NC,MM,L>& dest,
const matrix_mul_scal_exp<matrix<T,NR2,NC2,MM,L>,Sb>& src
)
{
// It's ok that we don't check for aliasing in this case because there isn't
// any complex unrolling of successive + or - operators in this expression.
matrix_assign_blas_proxy(dest,src.m,src.s,false, false);
}
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
......
......@@ -26,22 +26,42 @@ namespace dlib
int& counter_gemv();
int& counter_ger();
int& counter_dot();
int& counter_axpy();
int& counter_scal();
#define DLIB_TEST_BLAS_BINDING_GEMM ++counter_gemm();
#define DLIB_TEST_BLAS_BINDING_GEMV ++counter_gemv();
#define DLIB_TEST_BLAS_BINDING_GER ++counter_ger();
#define DLIB_TEST_BLAS_BINDING_DOT ++counter_dot();
#define DLIB_TEST_BLAS_BINDING_AXPY ++counter_axpy();
#define DLIB_TEST_BLAS_BINDING_SCAL ++counter_scal();
#else
#define DLIB_TEST_BLAS_BINDING_GEMM
#define DLIB_TEST_BLAS_BINDING_GEMV
#define DLIB_TEST_BLAS_BINDING_GER
#define DLIB_TEST_BLAS_BINDING_DOT
#define DLIB_TEST_BLAS_BINDING_AXPY
#define DLIB_TEST_BLAS_BINDING_SCAL
#endif
extern "C"
{
// Here we declare the prototypes for the CBLAS calls used by the BLAS bindings below
void cblas_saxpy(const int N, const float alpha, const float *X,
const int incX, float *Y, const int incY);
void cblas_daxpy(const int N, const double alpha, const double *X,
const int incX, double *Y, const int incY);
void cblas_caxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
void cblas_zaxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
void cblas_sscal(const int N, const float alpha, float *X, const int incX);
void cblas_dscal(const int N, const double alpha, double *X, const int incX);
void cblas_cscal(const int N, const void *alpha, void *X, const int incX);
void cblas_zscal(const int N, const void *alpha, void *X, const int incX);
void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const float alpha, const float *A,
......@@ -115,6 +135,62 @@ namespace dlib
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
inline void cblas_axpy(const int N, const float alpha, const float *X,
const int incX, float *Y, const int incY)
{
DLIB_TEST_BLAS_BINDING_AXPY;
cblas_saxpy(N, alpha, X, incX, Y, incY);
}
inline void cblas_axpy(const int N, const double alpha, const double *X,
const int incX, double *Y, const int incY)
{
DLIB_TEST_BLAS_BINDING_AXPY;
cblas_daxpy(N, alpha, X, incX, Y, incY);
}
inline void cblas_axpy(const int N, const std::complex<float>& alpha, const std::complex<float> *X,
const int incX, std::complex<float> *Y, const int incY)
{
DLIB_TEST_BLAS_BINDING_AXPY;
cblas_caxpy(N, &alpha, X, incX, Y, incY);
}
inline void cblas_axpy(const int N, const std::complex<double>& alpha, const std::complex<double> *X,
const int incX, std::complex<double> *Y, const int incY)
{
DLIB_TEST_BLAS_BINDING_AXPY;
cblas_zaxpy(N, &alpha, X, incX, Y, incY);
}
// ----------------------------------------------------------------------------------------
inline void cblas_scal(const int N, const float alpha, float *X)
{
DLIB_TEST_BLAS_BINDING_SCAL;
cblas_sscal(N, alpha, X, 1);
}
inline void cblas_scal(const int N, const double alpha, double *X)
{
DLIB_TEST_BLAS_BINDING_SCAL;
cblas_dscal(N, alpha, X, 1);
}
inline void cblas_scal(const int N, const std::complex<float>& alpha, std::complex<float> *X)
{
DLIB_TEST_BLAS_BINDING_SCAL;
cblas_cscal(N, &alpha, X, 1);
}
inline void cblas_scal(const int N, const std::complex<double>& alpha, std::complex<double> *X)
{
DLIB_TEST_BLAS_BINDING_SCAL;
cblas_zscal(N, &alpha, X, 1);
}
// ----------------------------------------------------------------------------------------
inline void cblas_gemm( const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
......@@ -460,6 +536,99 @@ namespace dlib
extern matrix<double,0,1> cv; // general column vector
extern const double s;
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// AXPY/SCAL overloads
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
DLIB_ADD_BLAS_BINDING(m)
{
if (transpose == false)
{
const int N = static_cast<int>(src.size());
if (add_to)
{
cblas_axpy(N, alpha, get_ptr(src), 1, get_ptr(dest), 1);
}
else
{
if (get_ptr(src) == get_ptr(dest))
{
cblas_scal(N, alpha, get_ptr(dest));
}
else
{
matrix_assign_default(dest, src, alpha, add_to);
}
}
}
else
{
matrix_assign_default(dest, trans(src), alpha, add_to);
}
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(rv)
{
if (transpose == false)
{
const int N = static_cast<int>(src.size());
if (add_to)
{
cblas_axpy(N, alpha, get_ptr(src), 1, get_ptr(dest), 1);
}
else
{
if (get_ptr(src) == get_ptr(dest))
{
cblas_scal(N, alpha, get_ptr(dest));
}
else
{
matrix_assign_default(dest, src, alpha, add_to);
}
}
}
else
{
matrix_assign_default(dest, trans(src), alpha, add_to);
}
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(cv)
{
if (transpose == false)
{
const int N = static_cast<int>(src.size());
if (add_to)
{
cblas_axpy(N, alpha, get_ptr(src), 1, get_ptr(dest), 1);
}
else
{
if (get_ptr(src) == get_ptr(dest))
{
cblas_scal(N, alpha, get_ptr(dest));
}
else
{
matrix_assign_default(dest, src, alpha, add_to);
}
}
}
else
{
matrix_assign_default(dest, trans(src), alpha, add_to);
}
} DLIB_END_BLAS_BINDING
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// GEMM overloads
......
......@@ -17,6 +17,7 @@ set (tests
blas_bindings_gemv.cpp
blas_bindings_ger.cpp
blas_bindings_dot.cpp
blas_bindings_scal_axpy.cpp
vector.cpp
)
......
// Copyright (C) 2009 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include "../tester.h"
#include <dlib/matrix.h>
#ifndef DLIB_USE_BLAS
#error "BLAS bindings must be used for this test to make any sense"
#endif
namespace dlib
{
namespace blas_bindings
{
// This is a little screwy. This function is used inside the BLAS
// bindings to count how many times each of the BLAS functions get called.
#ifdef DLIB_TEST_BLAS_BINDINGS
int& counter_axpy() { static int counter = 0; return counter; }
int& counter_scal() { static int counter = 0; return counter; }
#endif
}
}
namespace
{
using namespace test;
using namespace std;
// Declare the logger we will use in this test. The name of the logger
// should start with "test."
dlib::logger dlog("test.scal_axpy");
class blas_bindings_scal_axpy_tester : public tester
{
public:
blas_bindings_scal_axpy_tester (
) :
tester (
"test_scal_axpy", // the command line argument name for this test
"Run tests for DOT routines.", // the command line argument description
0 // the number of command line arguments for this test
)
{}
template <typename matrix_type, typename cv_type, typename rv_type>
void test_scal_axpy_stuff(
matrix_type& m,
rv_type& rv,
cv_type& cv
) const
{
using namespace dlib;
using namespace dlib::blas_bindings;
rv_type rv2 = rv;
cv_type cv2 = cv;
matrix_type m2 = m;
typedef typename matrix_type::type scalar_type;
scalar_type val;
counter_scal() = 0;
m = 5*m;
DLIB_TEST(counter_scal() == 1);
counter_scal() = 0;
rv = 5*rv;
DLIB_TEST(counter_scal() == 1);
counter_scal() = 0;
rv = 5*rv;
DLIB_TEST(counter_scal() == 1);
counter_axpy() = 0;
m2 += 5*m;
DLIB_TEST(counter_axpy() == 1);
counter_axpy() = 0;
rv2 += 5*rv;
DLIB_TEST(counter_axpy() == 1);
counter_axpy() = 0;
rv2 += 5*rv;
DLIB_TEST(counter_axpy() == 1);
counter_scal() = 0;
m = m*5;
DLIB_TEST(counter_scal() == 1);
counter_scal() = 0;
rv = rv*5;
DLIB_TEST(counter_scal() == 1);
counter_scal() = 0;
cv = cv*5;
DLIB_TEST(counter_scal() == 1);
counter_axpy() = 0;
m2 += m*5;
DLIB_TEST(counter_axpy() == 1);
counter_axpy() = 0;
rv2 += rv*5;
DLIB_TEST(counter_axpy() == 1);
counter_axpy() = 0;
cv2 += cv*5;
DLIB_TEST(counter_axpy() == 1);
counter_axpy() = 0;
m2 = m2 + m*5;
DLIB_TEST(counter_axpy() == 1);
counter_axpy() = 0;
rv2 = rv2 + rv*5;
DLIB_TEST(counter_axpy() == 1);
counter_axpy() = 0;
cv2 = cv2 + cv*5;
DLIB_TEST(counter_axpy() == 1);
counter_axpy() = 0;
cv2 = 1;
cv = 1;
cv2 = 2*cv2 + cv*5;
DLIB_TEST(counter_axpy() == 1);
DLIB_TEST(max(abs(cv2 - 7)) == 0);
counter_axpy() = 0;
rv2 = 1;
rv = 1;
rv2 = 2*rv2 + rv*5;
DLIB_TEST(counter_axpy() == 1);
DLIB_TEST(max(abs(rv2 - 7)) == 0);
counter_axpy() = 0;
m2 = 1;
m = 1;
m2 = 2*m2 + m*5;
DLIB_TEST(counter_axpy() == 1);
DLIB_TEST(max(abs(m2 - 7)) == 0);
}
void perform_test (
)
{
using namespace dlib;
typedef dlib::memory_manager<char>::kernel_1a mm;
dlog << dlib::LINFO << "test double";
{
matrix<double> m = randm(4,4);
matrix<double,1,0> rv = randm(1,4);
matrix<double,0,1> cv = randm(4,1);
test_scal_axpy_stuff(m,rv,cv);
}
dlog << dlib::LINFO << "test float";
{
matrix<float> m = matrix_cast<float>(randm(4,4));
matrix<float,1,0> rv = matrix_cast<float>(randm(1,4));
matrix<float,0,1> cv = matrix_cast<float>(randm(4,1));
test_scal_axpy_stuff(m,rv,cv);
}
dlog << dlib::LINFO << "test complex<double>";
{
matrix<complex<double> > m = complex_matrix(randm(4,4), randm(4,4));
matrix<complex<double>,1,0> rv = complex_matrix(randm(1,4), randm(1,4));
matrix<complex<double>,0,1> cv = complex_matrix(randm(4,1), randm(4,1));
test_scal_axpy_stuff(m,rv,cv);
}
dlog << dlib::LINFO << "test complex<float>";
{
matrix<complex<float> > m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
matrix<complex<float>,1,0> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
matrix<complex<float>,0,1> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
test_scal_axpy_stuff(m,rv,cv);
}
dlog << dlib::LINFO << "test double, column major";
{
matrix<double,0,0,mm,column_major_layout> m = randm(4,4);
matrix<double,1,0,mm,column_major_layout> rv = randm(1,4);
matrix<double,0,1,mm,column_major_layout> cv = randm(4,1);
test_scal_axpy_stuff(m,rv,cv);
}
dlog << dlib::LINFO << "test float, column major";
{
matrix<float,0,0,mm,column_major_layout> m = matrix_cast<float>(randm(4,4));
matrix<float,1,0,mm,column_major_layout> rv = matrix_cast<float>(randm(1,4));
matrix<float,0,1,mm,column_major_layout> cv = matrix_cast<float>(randm(4,1));
test_scal_axpy_stuff(m,rv,cv);
}
dlog << dlib::LINFO << "test complex<double>, column major";
{
matrix<complex<double>,0,0,mm,column_major_layout > m = complex_matrix(randm(4,4), randm(4,4));
matrix<complex<double>,1,0,mm,column_major_layout> rv = complex_matrix(randm(1,4), randm(1,4));
matrix<complex<double>,0,1,mm,column_major_layout> cv = complex_matrix(randm(4,1), randm(4,1));
test_scal_axpy_stuff(m,rv,cv);
}
dlog << dlib::LINFO << "test complex<float>, column major";
{
matrix<complex<float>,0,0,mm,column_major_layout > m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
matrix<complex<float>,1,0,mm,column_major_layout> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
matrix<complex<float>,0,1,mm,column_major_layout> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
test_scal_axpy_stuff(m,rv,cv);
}
print_spinner();
}
};
blas_bindings_scal_axpy_tester a;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment