Commit dea6bc04 authored by Davis King

   - Added overloads to cause scalar multiplications to combine and percolate out
     of matrix multiplications.
   - Worked more on the optimized overloads that call BLAS functions.
   - Changed the code inside the matrix assignment overloads so that it works
     better with GCC's optimizer.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402761
parent d346bdcf
#ifndef CBLAS_H
#define CBLAS_H
#include <stddef.h>
/* Allow the use in C++ code. */
#ifdef __cplusplus
extern "C"
{
#endif
/*
* Enumerated and derived types
*/
/* Return type of the cblas_i?amax index functions declared in this header. */
#define CBLAS_INDEX size_t /* this may vary between platforms */
/* Matrix storage order; first argument of every level 2 and level 3 routine here.
 * NOTE(review): the explicit values presumably must match the linked CBLAS
 * library -- do not renumber without confirming. */
enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
/* Operand transposition selector: none, transpose, or conjugate transpose. */
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
/* Upper/lower selector passed as the Uplo parameter. */
enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
/* Non-unit/unit selector passed as the Diag parameter of the tr/tb/tp routines. */
enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
/* Left/right selector passed as the Side parameter of level 3 routines. */
enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
/*
* ===========================================================================
* Prototypes for level 1 BLAS functions (complex are recast as routines)
* ===========================================================================
*/
/* Real dot-product family (names follow the BLAS reference). Each takes the
 * count N, the two read-only vectors X and Y, and their increments incX/incY,
 * and returns the result by value. Mixed precision is visible in the
 * signatures: sdsdot has an extra float alpha parameter, and dsdot takes float
 * inputs but returns double. */
float cblas_sdsdot(const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY);
double cblas_dsdot(const int N, const float *X, const int incX, const float *Y,
const int incY);
float cblas_sdot(const int N, const float *X, const int incX,
const float *Y, const int incY);
double cblas_ddot(const int N, const double *X, const int incX,
const double *Y, const int incY);
/*
* Functions having prefixes Z and C only
*/
/* Dot-product variants that take untyped (void *) vectors and write their
 * result through the final pointer argument (dotu / dotc) instead of returning
 * it. The element layout behind the void pointers is not defined by this
 * header; NOTE(review): per BLAS naming c = single and z = double precision
 * complex -- confirm against the CBLAS reference. */
void cblas_cdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_cdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
void cblas_zdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_zdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
/*
* Functions having prefixes S D SC DZ
*/
/* nrm2 / asum reductions over a single read-only vector X with increment incX
 * (norm and absolute-sum per the BLAS reference). The s/d variants take
 * float/double data; the sc/dz variants take void * data and return
 * float/double respectively. */
float cblas_snrm2(const int N, const float *X, const int incX);
float cblas_sasum(const int N, const float *X, const int incX);
double cblas_dnrm2(const int N, const double *X, const int incX);
double cblas_dasum(const int N, const double *X, const int incX);
float cblas_scnrm2(const int N, const void *X, const int incX);
float cblas_scasum(const int N, const void *X, const int incX);
double cblas_dznrm2(const int N, const void *X, const int incX);
double cblas_dzasum(const int N, const void *X, const int incX);
/*
* Functions having standard 4 prefixes (S D C Z)
*/
/* i?amax: return a CBLAS_INDEX (size_t here) computed from one read-only
 * vector. NOTE(review): per the BLAS reference this is the position of the
 * element of largest magnitude; whether it is 0- or 1-based is not stated in
 * this header -- confirm against the linked library. */
CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX);
CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX);
CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX);
CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 1 BLAS routines
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (s, d, c, z)
*/
/* swap / copy / axpy for each of the four prefixes. As the const qualifiers
 * show: swap takes both vectors writable, while copy and axpy read X and write
 * Y. The s/d axpy take alpha by value; the c/z axpy take alpha through a
 * const void pointer. */
void cblas_sswap(const int N, float *X, const int incX,
float *Y, const int incY);
void cblas_scopy(const int N, const float *X, const int incX,
float *Y, const int incY);
void cblas_saxpy(const int N, const float alpha, const float *X,
const int incX, float *Y, const int incY);
void cblas_dswap(const int N, double *X, const int incX,
double *Y, const int incY);
void cblas_dcopy(const int N, const double *X, const int incX,
double *Y, const int incY);
void cblas_daxpy(const int N, const double alpha, const double *X,
const int incX, double *Y, const int incY);
void cblas_cswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_ccopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_caxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
void cblas_zswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_zcopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_zaxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
/*
* Routines with S and D prefix only
*/
/* Rotation routines (rotg, rotmg, rot, rotm per the BLAS reference), float
 * then double. rotg/rotmg take only scalar pointers (b2 is passed by value);
 * rot/rotm take both vectors X and Y writable. P is a caller-supplied
 * parameter array whose length is not stated in this header. */
void cblas_srotg(float *a, float *b, float *c, float *s);
void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P);
void cblas_srot(const int N, float *X, const int incX,
float *Y, const int incY, const float c, const float s);
void cblas_srotm(const int N, float *X, const int incX,
float *Y, const int incY, const float *P);
void cblas_drotg(double *a, double *b, double *c, double *s);
void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P);
void cblas_drot(const int N, double *X, const int incX,
double *Y, const int incY, const double c, const double s);
void cblas_drotm(const int N, double *X, const int incX,
double *Y, const int incY, const double *P);
/*
* Routines with S D C Z CS and ZD prefixes
*/
/* scal: takes a scalar alpha and a writable vector X. s/d pass alpha by value;
 * c/z pass alpha through a const void pointer; csscal/zdscal combine a real
 * alpha (float/double, by value) with void * vector data. */
void cblas_sscal(const int N, const float alpha, float *X, const int incX);
void cblas_dscal(const int N, const double alpha, double *X, const int incX);
void cblas_cscal(const int N, const void *alpha, void *X, const int incX);
void cblas_zscal(const int N, const void *alpha, void *X, const int incX);
void cblas_csscal(const int N, const float alpha, void *X, const int incX);
void cblas_zdscal(const int N, const double alpha, void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 2 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
/* float level 2 routines. gemv/gbmv read A and X and write Y (with scalars
 * alpha and beta by value); the tr/tb/tp "mv" and "sv" routines read the
 * matrix (A with leading dimension lda, or Ap with no lda) and take X
 * writable. KL/KU and K are extra int parameters of the gb/tb variants. */
void cblas_sgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *X, const int incX, const float beta,
float *Y, const int incY);
void cblas_sgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const float alpha,
const float *A, const int lda, const float *X,
const int incX, const float beta, float *Y, const int incY);
void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda,
float *X, const int incX);
void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_strsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda, float *X,
const int incX);
void cblas_stbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
/* double level 2 routines; same parameter layout as the float versions with
 * double data and scalars. gemv/gbmv write Y; the tr/tb/tp mv and sv routines
 * take X writable. */
void cblas_dgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *X, const int incX, const double beta,
double *Y, const int incY);
void cblas_dgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const double alpha,
const double *A, const int lda, const double *X,
const int incX, const double beta, double *Y, const int incY);
void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda,
double *X, const int incX);
void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda, double *X,
const int incX);
void cblas_dtbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
/* c-prefixed level 2 routines. All data operands, and the scalars alpha and
 * beta, are passed through void pointers; element layout is defined by the
 * CBLAS reference rather than this header. gemv/gbmv write Y; the tr/tb/tp mv
 * and sv routines take X writable. */
void cblas_cgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_cgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ctrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ctbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
/* z-prefixed level 2 routines; signatures are identical to the c-prefixed
 * ones (void pointers for all data and for alpha/beta). gemv/gbmv write Y; the
 * tr/tb/tp mv and sv routines take X writable. */
void cblas_zgemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_zgbmv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ztrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ztbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
/*
* Routines with S and D prefixes only
*/
/* float routines for the sy/sb/sp matrix kinds, plus ger. symv/sbmv/spmv read
 * the matrix and X and write Y; ger, syr, spr, syr2 and spr2 read the
 * vector(s) and write the matrix argument. All but ger take an Uplo selector.
 * Note that spr names its matrix Ap while spr2 names it A; neither takes an
 * lda. */
void cblas_ssymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_ssbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *Ap,
const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sger(const enum CBLAS_ORDER order, const int M, const int N,
const float alpha, const float *X, const int incX,
const float *Y, const int incY, float *A, const int lda);
void cblas_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *A, const int lda);
void cblas_sspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *Ap);
void cblas_ssyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A,
const int lda);
void cblas_sspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A);
/* double counterparts of the float sy/sb/sp/ger routines, with the same
 * parameter order. symv/sbmv/spmv write Y; ger, syr, spr, syr2 and spr2 write
 * the matrix argument. */
void cblas_dsymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dsbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *Ap,
const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N,
const double alpha, const double *X, const int incX,
const double *Y, const int incY, double *A, const int lda);
void cblas_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *A, const int lda);
void cblas_dspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *Ap);
void cblas_dsyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A,
const int lda);
void cblas_dspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A);
/*
* Routines with C and Z prefixes only
*/
/* c-prefixed he/hb/hp and geru/gerc routines; all data goes through void
 * pointers. hemv/hbmv/hpmv write Y; geru, gerc, her, hpr, her2 and hpr2 write
 * the matrix argument. Scalar passing differs by routine: cher and chpr take
 * alpha as a float by value, every other routine here takes alpha (and beta,
 * where present) through a const void pointer. */
void cblas_chemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_cgeru(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cgerc(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_chpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X,
const int incX, void *A);
void cblas_cher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_chpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
/* z-prefixed he/hb/hp and geru/gerc routines; same shapes as the c-prefixed
 * ones. hemv/hbmv/hpmv write Y; the remaining routines write the matrix
 * argument. zher and zhpr take alpha as a double by value; every other routine
 * here takes its scalars through const void pointers. */
void cblas_zhemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zgeru(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zgerc(const enum CBLAS_ORDER order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_zhpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X,
const int incX, void *A);
void cblas_zher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zhpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
/*
* ===========================================================================
* Prototypes for level 3 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
/* float level 3 routines. gemm, symm, syrk and syr2k read A (and B where
 * present) and write C, with alpha and beta by value; trmm and trsm read A and
 * take B writable, with alpha only. Each matrix pointer is followed by its
 * leading dimension (lda/ldb/ldc). */
void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const float alpha, const float *A,
const int lda, const float *B, const int ldb,
const float beta, float *C, const int ldc);
void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float beta, float *C, const int ldc);
void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
/* double level 3 routines; same parameter order as the float versions.
 * gemm, symm, syrk and syr2k write C; trmm and trsm take B writable. */
void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const double alpha, const double *A,
const int lda, const double *B, const int ldb,
const double beta, double *C, const int ldc);
void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double beta, double *C, const int ldc);
void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
/* c-prefixed level 3 routines. Matrices and the scalars alpha/beta are all
 * passed through void pointers. gemm, symm, syrk and syr2k write C; trmm and
 * trsm take B writable. */
void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
/* z-prefixed level 3 routines; signatures are identical to the c-prefixed
 * ones (void pointers for matrices and for alpha/beta). gemm, symm, syrk and
 * syr2k write C; trmm and trsm take B writable. */
void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
/*
* Routines with prefixes C and Z only
*/
/* hemm / herk / her2k for the c and z prefixes; all write C. Scalar passing is
 * not uniform, so check each signature: hemm takes alpha and beta through
 * const void pointers; herk takes both as real values (float for c, double for
 * z); her2k takes alpha through a const void pointer but beta as a real value. */
void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const void *A, const int lda,
const float beta, void *C, const int ldc);
void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const float beta,
void *C, const int ldc);
void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const void *A, const int lda,
const double beta, void *C, const int ldc);
void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const double beta,
void *C, const int ldc);
/* Variadic error hook: takes an int p, a routine name rout, and a string form
 * followed by variable arguments. NOTE(review): presumably p is the position
 * of the offending argument and form is a printf-style format -- confirm
 * against the CBLAS reference implementation. */
void cblas_xerbla(int p, const char *rout, const char *form, ...);
#ifdef __cplusplus
}
#endif
#endif
......@@ -308,6 +308,8 @@ namespace dlib
// Per-operand cost flags, forwarded from matrix_traits<matrix_multiply_exp>.
const static bool lhs_is_costly = matrix_traits<matrix_multiply_exp>::lhs_is_costly;
const static bool rhs_is_costly = matrix_traits<matrix_multiply_exp>::rhs_is_costly;
// Convenience combinations: at least one operand costly / both operands costly.
const static bool either_is_costly = lhs_is_costly || rhs_is_costly;
const static bool both_are_costly = lhs_is_costly && rhs_is_costly;
// Storage types for the two operands. The boolean argument is true when the
// operand is NOT costly.
// NOTE(review): conditional_matrix_temp is defined elsewhere; presumably it
// yields a reference for cheap operands and an evaluated temporary for costly
// ones -- confirm.
typedef typename conditional_matrix_temp<const LHS,lhs_is_costly == false>::type LHS_ref_type;
typedef typename conditional_matrix_temp<const RHS,rhs_is_costly == false>::type RHS_ref_type;
......@@ -387,6 +389,59 @@ namespace dlib
return matrix_multiply_exp<EXP1, EXP2>(m1.ref(), m2.ref());
}
// Forward declaration of the "matrix expression times scalar" expression type.
// use_reference defaults to true here, so matrix_mul_scal_exp<M> names the
// reference-holding mode; the operator* overloads that follow explicitly pass
// false to get the mode that stores its operand by value.
template <typename M, bool use_reference = true>
class matrix_mul_scal_exp;
// -------------------------
// Now we declare some overloads that cause any scalar multiplications to percolate
// up and outside of any matrix multiplies. Note that we are using the non-reference containing
// mode of the matrix_mul_scal_exp object since we are passing in locally constructed matrix_multiply_exp
// objects. So the matrix_mul_scal_exp object will contain copies of matrix_multiply_exp objects
// rather than references to them. This could result in extra matrix copies if the matrix_multiply_exp
// decided it should evaluate any of its arguments. So we also try to not apply this percolating operation
// if the matrix_multiply_exp would contain a fully evaluated copy of the original matrix_mul_scal_exp
// expression.
//
// Also, the reason we want to apply this transformation in the first place is because it (1) makes
// the expressions going into matrix multiply expressions simpler and (2) it makes it a lot more
// straightforward to bind BLAS calls to matrix expressions involving scalar multiplies.
// (s1*A) * (s2*B)  ==>  (s1*s2) * (A*B)
//
// Both operands are scaled expressions: multiply the two unscaled operands and
// attach the product of the two scalars to the result. The result uses the
// non-reference mode (false) because the inner matrix_multiply_exp is a
// temporary built here. The overload is disabled when the plain product of the
// two scaled expressions would flag both operands as costly.
template <typename EXP1, typename EXP2>
inline const typename disable_if_c<
    matrix_multiply_exp<matrix_mul_scal_exp<EXP1>, matrix_mul_scal_exp<EXP2> >::both_are_costly,
    matrix_mul_scal_exp<matrix_multiply_exp<EXP1, EXP2>, false>
>::type operator* (
    const matrix_mul_scal_exp<EXP1>& m1,
    const matrix_mul_scal_exp<EXP2>& m2
)
{
    typedef matrix_multiply_exp<EXP1, EXP2> product_type;
    typedef matrix_mul_scal_exp<product_type, false> scaled_product_type;

    return scaled_product_type(product_type(m1.m, m2.m), m1.s*m2.s);
}
// (s*A) * B  ==>  s * (A*B)
//
// Only the left operand is a scaled expression: multiply its unscaled operand
// by the right-hand expression and carry the scalar over to the result. The
// result uses the non-reference mode (false) because the inner
// matrix_multiply_exp is a temporary built here. The overload is disabled when
// the plain product would flag its left operand as costly.
template <typename EXP1, typename EXP2>
inline const typename disable_if_c<
    matrix_multiply_exp<matrix_mul_scal_exp<EXP1>, EXP2 >::lhs_is_costly,
    matrix_mul_scal_exp<matrix_multiply_exp<EXP1, EXP2>, false>
>::type operator* (
    const matrix_mul_scal_exp<EXP1>& m1,
    const matrix_exp<EXP2>& m2
)
{
    typedef matrix_multiply_exp<EXP1, EXP2> product_type;
    typedef matrix_mul_scal_exp<product_type, false> scaled_product_type;

    return scaled_product_type(product_type(m1.m, m2.ref()), m1.s);
}
// A * (s*B)  ==>  s * (A*B)
//
// Only the right operand is a scaled expression: multiply the left-hand
// expression by its unscaled operand and carry the scalar over to the result.
// The result uses the non-reference mode (false) because the inner
// matrix_multiply_exp is a temporary built here. The overload is disabled when
// the plain product would flag its right operand as costly.
template <typename EXP1, typename EXP2>
inline const typename disable_if_c<
    matrix_multiply_exp<EXP1, matrix_mul_scal_exp<EXP2> >::rhs_is_costly,
    matrix_mul_scal_exp<matrix_multiply_exp<EXP1, EXP2>, false>
>::type operator* (
    const matrix_exp<EXP1>& m1,
    const matrix_mul_scal_exp<EXP2>& m2
)
{
    typedef matrix_multiply_exp<EXP1, EXP2> product_type;
    typedef matrix_mul_scal_exp<product_type, false> scaled_product_type;

    return scaled_product_type(product_type(m1.ref(), m2.m), m2.s);
}
// ----------------------------------------------------------------------------------------
template <typename LHS, typename RHS>
......@@ -689,21 +744,18 @@ namespace dlib
typename EXP,
typename S
>
inline const matrix_div_scal_exp<EXP> operator/ (
inline const typename enable_if_c<std::numeric_limits<typename EXP::type>::is_integer, matrix_div_scal_exp<EXP> >::type operator/ (
const matrix_exp<EXP>& m,
const S& s
)
{
return matrix_div_scal_exp<EXP>(m.ref(),s);
return matrix_div_scal_exp<EXP>(m.ref(),static_cast<typename EXP::type>(s));
}
// ----------------------------------------------------------------------------------------
template <typename M>
class matrix_mul_scal_exp;
template <typename M>
struct matrix_traits<matrix_mul_scal_exp<M> >
template <typename M, bool use_reference >
struct matrix_traits<matrix_mul_scal_exp<M,use_reference> >
{
typedef typename M::type type;
typedef typename M::mem_manager_type mem_manager_type;
......@@ -713,10 +765,15 @@ namespace dlib
const static long cost = M::cost+1;
};
// Compile-time type selector: conditional_reference<T, flag>::type is T when
// flag is false and T& when flag is true.
template <typename T, bool is_ref>
struct conditional_reference
{
    typedef T type;
};

// Specialization for flag == true: yields a reference to T.
template <typename T>
struct conditional_reference<T, true>
{
    typedef T& type;
};
template <
typename M
typename M,
bool use_reference
>
class matrix_mul_scal_exp : public matrix_exp<matrix_mul_scal_exp<M> >
class matrix_mul_scal_exp : public matrix_exp<matrix_mul_scal_exp<M,use_reference> >
{
/*!
REQUIREMENTS ON M
......@@ -773,7 +830,9 @@ namespace dlib
long nc (
) const { return m.nc(); }
const M& m;
typedef typename conditional_reference<const M,use_reference>::type M_ref_type;
M_ref_type m;
const type s;
};
......@@ -789,6 +848,19 @@ namespace dlib
return matrix_mul_scal_exp<EXP>(m.ref(),s);
}
// scalar-scaled expression * scalar.
// Rather than wrapping m in a second scaling node, the new scalar is folded
// into the existing one: the result wraps m.m directly with scalar s*m.s.
// disable_if<is_matrix<S>, ...> restricts S to non-matrix types.
template <
    typename EXP,
    typename S,
    bool B
    >
inline typename disable_if<is_matrix<S>, const matrix_mul_scal_exp<EXP> >::type operator* (
    const matrix_mul_scal_exp<EXP,B>& m,
    const S& s
)
{
    typedef matrix_mul_scal_exp<EXP> combined_type;
    return combined_type(m.m, s*m.s);
}
template <
typename EXP,
typename S
......@@ -802,36 +874,44 @@ namespace dlib
}
template <
typename EXP
typename EXP,
typename S,
bool B
>
inline const matrix_mul_scal_exp<EXP> operator/ (
const matrix_exp<EXP>& m,
const float& s
inline typename disable_if<is_matrix<S>, const matrix_mul_scal_exp<EXP> >::type operator* (
const S& s,
const matrix_mul_scal_exp<EXP,B>& m
)
{
return matrix_mul_scal_exp<EXP>(m.ref(),1.0f/s);
return matrix_mul_scal_exp<EXP>(m.m,s*m.s);
}
template <
typename EXP
typename EXP ,
typename S
>
inline const matrix_mul_scal_exp<EXP> operator/ (
inline const typename disable_if_c<std::numeric_limits<typename EXP::type>::is_integer, matrix_mul_scal_exp<EXP> >::type operator/ (
const matrix_exp<EXP>& m,
const double& s
const S& s
)
{
return matrix_mul_scal_exp<EXP>(m.ref(),1.0/s);
typedef typename EXP::type type;
const type one = 1;
return matrix_mul_scal_exp<EXP>(m.ref(),one/static_cast<type>(s));
}
template <
typename EXP
typename EXP,
bool B,
typename S
>
inline const matrix_mul_scal_exp<EXP> operator/ (
const matrix_exp<EXP>& m,
const long double& s
inline const typename disable_if_c<std::numeric_limits<typename EXP::type>::is_integer, matrix_mul_scal_exp<EXP> >::type operator/ (
const matrix_mul_scal_exp<EXP,B>& m,
const S& s
)
{
return matrix_mul_scal_exp<EXP>(m.ref(),1.0/s);
typedef typename EXP::type type;
return matrix_mul_scal_exp<EXP>(m.m,m.s/static_cast<type>(s));
}
template <
......@@ -844,6 +924,17 @@ namespace dlib
return matrix_mul_scal_exp<EXP>(m.ref(),-1);
}
// Unary minus for a scalar-scaled expression.
// The sign flip is absorbed into the stored scalar (-1*m.s) so the result
// still wraps m.m directly instead of nesting another scaling node.
template <
    typename EXP,
    bool B
    >
inline const matrix_mul_scal_exp<EXP> operator- (
    const matrix_mul_scal_exp<EXP,B>& m
)
{
    typedef matrix_mul_scal_exp<EXP> negated_type;
    return negated_type(m.m, -1*m.s);
}
// ----------------------------------------------------------------------------------------
template <
......@@ -1260,10 +1351,20 @@ namespace dlib
COMPILE_TIME_ASSERT((is_same_type<typename EXP::type,type>::value == true) ||
(is_matrix<typename EXP::type>::value == true));
if (m.destructively_aliases(*this) == false)
{
// This if statement is seemingly unnecessary since set_size() contains this
// exact same if statement. However, structuring the code this way causes
// gcc to handle the way it inlines this function in a much more favorable way.
if (data.nr() == m.nr() && data.nc() == m.nc())
{
matrix_assign(*this, m);
}
else
{
set_size(m.nr(),m.nc());
matrix_assign(*this, m);
}
}
else
{
// we have to use a temporary matrix object here because
......
......@@ -69,7 +69,7 @@ namespace dlib
struct same_exp<matrix_subtract_exp<Tlhs,Trhs>, matrix_subtract_exp<Ulhs,Urhs> >
{ const static bool value = same_exp<Tlhs,Ulhs>::value && same_exp<Trhs,Urhs>::value; };
template <typename T, typename U> struct same_exp<matrix_mul_scal_exp<T>, matrix_mul_scal_exp<U> >
template <typename T, typename U, bool Tb, bool Ub> struct same_exp<matrix_mul_scal_exp<T,Tb>, matrix_mul_scal_exp<U,Ub> >
{ const static bool value = same_exp<T,U>::value; };
template <typename T, typename U> struct same_exp<matrix_div_scal_exp<T>, matrix_div_scal_exp<U> >
......@@ -113,13 +113,7 @@ namespace dlib
const EXP& src
)
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) = src(r,c);
}
}
matrix_assign_default(dest,src);
}
// If we know this is a matrix multiply then apply the
......@@ -134,6 +128,75 @@ namespace dlib
set_all_elements(dest,0);
default_matrix_multiply(dest, src.lhs, src.rhs);
}
// dest = M + A*B, where M (src.lhs) has the same matrix type as dest.
// dest is first made equal to M (skipped when M already is dest) and then
// default_matrix_multiply() is run on dest with the two product operands.
// NOTE(review): this presumes default_matrix_multiply() accumulates into dest
// (the plain-multiply overload zeroes dest before calling it) - confirm.
template <typename EXP1, typename EXP2>
static void assign (
    matrix<T,NR,NC,MM,L>& dest,
    const matrix_add_exp<matrix<T,NR,NC,MM,L>, matrix_multiply_exp<EXP1,EXP2> >& src
)
{
    // Only copy the additive term when it is a different object than dest.
    if (&dest != &src.lhs)
        dest = src.lhs;

    default_matrix_multiply(dest, src.rhs.lhs, src.rhs.rhs);
}
// dest = M + (E1 + E2), where M (src.lhs) has the same matrix type as dest,
// E1 is src.rhs.lhs and E2 is src.rhs.rhs.
//
// When either inner operand is flagged as costly the sum is evaluated in two
// passes so each operand can be dispatched through matrix_assign() separately:
//   pass 1: dest = M + E1
//   pass 2: dest = dest + E2
// Otherwise the whole expression is evaluated element by element.
//
// The second pass must accumulate onto dest. Re-adding src.lhs there (as the
// code previously did) throws away the E1 term whenever dest and src.lhs are
// different objects, leaving dest == M + E2.
//
// NOTE(review): the EXP2 threshold is 5 while every sibling overload uses 50.
// This only affects which evaluation path is chosen; it looks like a typo but
// is left unchanged here - confirm.
template <typename EXP1, typename EXP2>
static void assign (
    matrix<T,NR,NC,MM,L>& dest,
    const matrix_add_exp<matrix<T,NR,NC,MM,L>, matrix_add_exp<EXP1,EXP2> >& src
)
{
    if (EXP1::cost > 50 || EXP2::cost > 5)
    {
        // pass 1: dest = M + E1
        matrix_assign(dest, src.lhs + src.rhs.lhs);
        // pass 2: add E2 to what pass 1 produced
        matrix_assign(dest, dest + src.rhs.rhs);
    }
    else
    {
        matrix_assign_default(dest,src);
    }
}
// dest = M + E, where M (src.lhs) has the same matrix type as dest.
// If E is flagged as costly and M is not dest itself, dest is loaded with M
// and E is then added through matrix_assign(); in every other case the
// expression is evaluated element by element.
template <typename EXP2>
static void assign (
    matrix<T,NR,NC,MM,L>& dest,
    const matrix_add_exp<matrix<T,NR,NC,MM,L>,EXP2>& src
)
{
    const bool lhs_is_dest = (&dest == &src.lhs);

    // Cheap right operand, or dest already holds M: plain element-wise copy.
    if (EXP2::cost <= 50 || lhs_is_dest)
    {
        matrix_assign_default(dest,src);
        return;
    }

    dest = src.lhs;
    matrix_assign(dest, dest + src.rhs);
}
// dest = E1 + E2 for arbitrary expression operands.
// If either operand is flagged as costly the sum is split in two passes
// (dest = E1, then dest = dest + E2) so each operand goes through
// matrix_assign() on its own; otherwise it is evaluated element by element.
template <typename EXP1, typename EXP2>
static void assign (
    matrix<T,NR,NC,MM,L>& dest,
    const matrix_add_exp<EXP1,EXP2>& src
)
{
    // Both operands cheap: plain element-wise copy.
    if (EXP1::cost <= 50 && EXP2::cost <= 50)
    {
        matrix_assign_default(dest,src);
        return;
    }

    matrix_assign(dest,src.lhs);
    matrix_assign(dest, dest + src.rhs);
}
};
// This is a macro to help us add overloads for the matrix_assign_blas_helper template.
......
......@@ -39,27 +39,10 @@ namespace dlib
// ----------------------------------------------------------------------------------------
// In newer versions of GCC it is necessary to explicitly tell it to not try to
// inline the matrix_assign() function when working with matrix objects that
// don't have dimensions that are known at compile time. Doing this makes the
// resulting binaries a lot faster when -O3 is used. This whole deal with
// different versions of matrix_assign() is just to support getting the right
// inline behavior out of GCC.
#ifdef __GNUC__
#define DLIB_DONT_INLINE __attribute__((noinline))
#define DLIB_ALWAYS_INLINE __attribute__((always_inline))
#else
#define DLIB_DONT_INLINE
#define DLIB_ALWAYS_INLINE
#endif
template <
typename matrix_dest_type,
typename src_exp
>
DLIB_DONT_INLINE void matrix_assign_big (
matrix_dest_type& dest,
const matrix_exp<src_exp>& src
template <typename EXP1, typename EXP2>
inline static void matrix_assign_default (
EXP1& dest,
const EXP2& src
)
{
for (long r = 0; r < src.nr(); ++r)
......@@ -71,22 +54,18 @@ namespace dlib
}
}
// ----------------------------------------------------------------------------------------
template <
typename matrix_dest_type,
typename src_exp
>
inline void matrix_assign_small (
void matrix_assign_big (
matrix_dest_type& dest,
const matrix_exp<src_exp>& src
)
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r,c) = src(r,c);
}
}
matrix_assign_default(dest,src);
}
// ----------------------------------------------------------------------------------------
......@@ -131,7 +110,7 @@ namespace dlib
- the part of dest outside the above sub matrix remains unchanged
!*/
{
matrix_assign_small(dest,src.ref());
matrix_assign_default(dest,src.ref());
}
// ----------------------------------------------------------------------------------------
......
......@@ -6,7 +6,7 @@
#include "matrix_assign.h"
#ifdef DLIB_FOUND_BLAS
#include "mkl_cblas.h"
#include "cblas.h"
#endif
namespace dlib
......@@ -32,9 +32,9 @@ namespace dlib
extern matrix<double,0,0,mm,column_major_layout> cm; // general matrix with column major order
extern matrix<double,1,0> rv; // general row vector
extern matrix<double,0,1> cv; // general column vector
extern const double s;
using namespace std;
#ifdef DLIB_FOUND_BLAS
......@@ -59,6 +59,33 @@ namespace dlib
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(double, row_major_layout,rm + rm*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const double alpha = 1;
const double* A = &src.rhs.lhs(0,0);
const int lda = src.rhs.lhs.nc();
const double* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const double beta = 1;
double* C = &dest(0,0);
const int ldc = src.rhs.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(double, row_major_layout, trans(rm)*rm)
{
......@@ -81,8 +108,84 @@ namespace dlib
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(double, row_major_layout, rm + s*trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc());
const int K = static_cast<int>(src.rhs.m.lhs.nc());
const double alpha = src.rhs.s;
const double* A = &src.rhs.m.lhs.m(0,0);
const int lda = src.rhs.m.lhs.m.nc();
const double* B = &src.rhs.m.rhs(0,0);
const int ldb = src.rhs.m.rhs.nc();
const double beta = 1;
double* C = &dest(0,0);
const int ldc = dest.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(double, row_major_layout, s*trans(rm)*rm)
{
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.m.nr());
const int N = static_cast<int>(src.m.nc());
const int K = static_cast<int>(src.m.lhs.nc());
const double alpha = src.s;
const double* A = &src.m.lhs.m(0,0);
const int lda = src.m.lhs.m.nc();
const double* B = &src.m.rhs(0,0);
const int ldb = src.m.rhs.nc();
const double beta = 0;
double* C = &dest(0,0);
const int ldc = dest.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(double, row_major_layout, rm + trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const double alpha = 1;
const double* A = &src.rhs.lhs.m(0,0);
const int lda = src.rhs.lhs.m.nc();
const double* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const double beta = 1;
double* C = &dest(0,0);
const int ldc = src.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// ---------------------------------------------------------------------
// -------------------------- float overloads --------------------------
// ---------------------------------------------------------------------
DLIB_ADD_BLAS_BINDING(float, row_major_layout, rm*rm)
{
......@@ -105,6 +208,33 @@ namespace dlib
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(float, row_major_layout,rm + rm*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const float alpha = 1;
const float* A = &src.rhs.lhs(0,0);
const int lda = src.rhs.lhs.nc();
const float* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const float beta = 1;
float* C = &dest(0,0);
const int ldc = src.rhs.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(float, row_major_layout, trans(rm)*rm)
{
......@@ -127,6 +257,80 @@ namespace dlib
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(float, row_major_layout, rm + s*trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.m.nr());
const int N = static_cast<int>(src.rhs.m.nc());
const int K = static_cast<int>(src.rhs.m.lhs.nc());
const float alpha = src.rhs.s;
const float* A = &src.rhs.m.lhs.m(0,0);
const int lda = src.rhs.m.lhs.m.nc();
const float* B = &src.rhs.m.rhs(0,0);
const int ldb = src.rhs.m.rhs.nc();
const float beta = 1;
float* C = &dest(0,0);
const int ldc = dest.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(float, row_major_layout, s*trans(rm)*rm)
{
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.m.nr());
const int N = static_cast<int>(src.m.nc());
const int K = static_cast<int>(src.m.lhs.nc());
const float alpha = src.s;
const float* A = &src.m.lhs.m(0,0);
const int lda = src.m.lhs.m.nc();
const float* B = &src.m.rhs(0,0);
const int ldb = src.m.rhs.nc();
const float beta = 0;
float* C = &dest(0,0);
const int ldc = dest.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
DLIB_ADD_BLAS_BINDING(float, row_major_layout, rm + trans(rm)*rm)
{
if (&src.lhs != &dest)
{
dest = src.lhs;
}
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.rhs.nr());
const int N = static_cast<int>(src.rhs.nc());
const int K = static_cast<int>(src.rhs.lhs.nc());
const float alpha = 1;
const float* A = &src.rhs.lhs.m(0,0);
const int lda = src.rhs.lhs.m.nc();
const float* B = &src.rhs.rhs(0,0);
const int ldb = src.rhs.rhs.nc();
const float beta = 1;
float* C = &dest(0,0);
const int ldc = src.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
} DLIB_END_BLAS_BINDING
#endif // DLIB_FOUND_BLAS
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment