Commit 715ca0da authored by Davis King

Added more BLAS bindings.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402777
parent de22808a
...@@ -237,7 +237,7 @@ namespace dlib ...@@ -237,7 +237,7 @@ namespace dlib
// At some point I need to improve the default (i.e. non BLAS) matrix // At some point I need to improve the default (i.e. non BLAS) matrix
// multiplication algorithm... // multiplication algorithm...
if (alpha == 1) if (alpha == static_cast<typename src_exp::type>(1))
{ {
if (add_to) if (add_to)
{ {
...@@ -283,8 +283,7 @@ namespace dlib ...@@ -283,8 +283,7 @@ namespace dlib
typename src_exp::type alpha, \ typename src_exp::type alpha, \
bool add_to \ bool add_to \
) { \ ) { \
typedef typename dest_exp::type T; \ typedef typename dest_exp::type T;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
#define DLIB_END_BLAS_BINDING }}; #define DLIB_END_BLAS_BINDING }};
...@@ -455,7 +454,7 @@ namespace dlib ...@@ -455,7 +454,7 @@ namespace dlib
bool add_to bool add_to
) )
{ {
if (src_exp::cost > 9 || src_exp2::cost > 9) if (has_matrix_multiply<src_exp>::value || has_matrix_multiply<src_exp2>::value)
{ {
matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to); matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to);
matrix_assign_blas_proxy(dest, src.rhs, alpha, true); matrix_assign_blas_proxy(dest, src.rhs, alpha, true);
...@@ -495,7 +494,8 @@ namespace dlib ...@@ -495,7 +494,8 @@ namespace dlib
bool add_to bool add_to
) )
{ {
if (src_exp::cost > 9 || src_exp2::cost > 9)
if (has_matrix_multiply<src_exp>::value || has_matrix_multiply<src_exp2>::value)
{ {
matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to); matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to);
matrix_assign_blas_proxy(dest, src.rhs, -alpha, true); matrix_assign_blas_proxy(dest, src.rhs, -alpha, true);
...@@ -614,27 +614,20 @@ namespace dlib ...@@ -614,27 +614,20 @@ namespace dlib
const matrix_add_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src const matrix_add_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
) )
{ {
if (src_exp::cost > 5) if (src.rhs.aliases(dest) == false)
{ {
if (src.rhs.aliases(dest) == false) if (&src.lhs != &dest)
{ {
if (&src.lhs != &dest) dest = src.lhs;
{
dest = src.lhs;
}
matrix_assign_blas_proxy(dest, src.rhs, 1, true);
}
else
{
matrix<T,NR,NC,MM,L> temp(src.lhs);
matrix_assign_blas_proxy(temp, src.rhs, 1, true);
temp.swap(dest);
} }
matrix_assign_blas_proxy(dest, src.rhs, 1, true);
} }
else else
{ {
matrix_assign_default(dest,src); matrix<T,NR,NC,MM,L> temp(src.lhs);
matrix_assign_blas_proxy(temp, src.rhs, 1, true);
temp.swap(dest);
} }
} }
...@@ -667,27 +660,20 @@ namespace dlib ...@@ -667,27 +660,20 @@ namespace dlib
const matrix_subtract_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src const matrix_subtract_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src
) )
{ {
if (src_exp::cost > 5) if (src.rhs.aliases(dest) == false)
{ {
if (src.rhs.aliases(dest) == false) if (&src.lhs != &dest)
{ {
if (&src.lhs != &dest) dest = src.lhs;
{
dest = src.lhs;
}
matrix_assign_blas_proxy(dest, src.rhs, -1, true);
}
else
{
matrix<T,NR,NC,MM,L> temp(src.lhs);
matrix_assign_blas_proxy(temp, src.rhs, -1, true);
temp.swap(dest);
} }
matrix_assign_blas_proxy(dest, src.rhs, -1, true);
} }
else else
{ {
matrix_assign_default(dest,src); matrix<T,NR,NC,MM,L> temp(src.lhs);
matrix_assign_blas_proxy(temp, src.rhs, -1, true);
temp.swap(dest);
} }
} }
......
...@@ -28,7 +28,7 @@ namespace dlib ...@@ -28,7 +28,7 @@ namespace dlib
struct is_small_matrix { static const bool value = false; }; struct is_small_matrix { static const bool value = false; };
template < typename EXP > template < typename EXP >
struct is_small_matrix<EXP, typename enable_if_c<EXP::NR>=1 && EXP::NC>=1 && struct is_small_matrix<EXP, typename enable_if_c<EXP::NR>=1 && EXP::NC>=1 &&
EXP::NR<=100 && EXP::NC<=100>::type > { static const bool value = true; }; EXP::NR<=100 && EXP::NC<=100 && (EXP::cost < 70)>::type> { static const bool value = true; };
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -114,7 +114,7 @@ namespace dlib ...@@ -114,7 +114,7 @@ namespace dlib
{ {
if (add_to) if (add_to)
{ {
if (alpha == 1) if (alpha == static_cast<typename EXP2::type>(1))
{ {
for (long r = 0; r < src.nr(); ++r) for (long r = 0; r < src.nr(); ++r)
{ {
...@@ -124,7 +124,7 @@ namespace dlib ...@@ -124,7 +124,7 @@ namespace dlib
} }
} }
} }
else if (alpha == -1) else if (alpha == static_cast<typename EXP2::type>(-1))
{ {
for (long r = 0; r < src.nr(); ++r) for (long r = 0; r < src.nr(); ++r)
{ {
...@@ -147,7 +147,7 @@ namespace dlib ...@@ -147,7 +147,7 @@ namespace dlib
} }
else else
{ {
if (alpha == 1) if (alpha == static_cast<typename EXP2::type>(1))
{ {
for (long r = 0; r < src.nr(); ++r) for (long r = 0; r < src.nr(); ++r)
{ {
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#endif #endif
#include "matrix_assign.h" #include "matrix_assign.h"
#include "matrix_conj_trans.h"
#include "cblas.h" #include "cblas.h"
...@@ -105,17 +106,17 @@ namespace dlib ...@@ -105,17 +106,17 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
inline void cblas_ger(const enum CBLAS_ORDER order, const int M, const int N, inline void cblas_ger(const enum CBLAS_ORDER order, const int M, const int N,
const std::complex<float> *alpha, const std::complex<float> *X, const int incX, const std::complex<float>& alpha, const std::complex<float> *X, const int incX,
const std::complex<float> *Y, const int incY, std::complex<float> *A, const int lda) const std::complex<float> *Y, const int incY, std::complex<float> *A, const int lda)
{ {
cblas_cgeru (order, M, N, alpha, X, incX, Y, incY, A, lda); cblas_cgeru (order, M, N, &alpha, X, incX, Y, incY, A, lda);
} }
inline void cblas_ger(const enum CBLAS_ORDER order, const int M, const int N, inline void cblas_ger(const enum CBLAS_ORDER order, const int M, const int N,
const std::complex<double> *alpha, const std::complex<double> *X, const int incX, const std::complex<double>& alpha, const std::complex<double> *X, const int incX,
const std::complex<double> *Y, const int incY, std::complex<double> *A, const int lda) const std::complex<double> *Y, const int incY, std::complex<double> *A, const int lda)
{ {
cblas_zgeru (order, M, N, alpha, X, incX, Y, incY, A, lda); cblas_zgeru (order, M, N, &alpha, X, incX, Y, incY, A, lda);
} }
inline void cblas_ger(const enum CBLAS_ORDER order, const int M, const int N, inline void cblas_ger(const enum CBLAS_ORDER order, const int M, const int N,
...@@ -135,17 +136,17 @@ namespace dlib ...@@ -135,17 +136,17 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
inline void cblas_gerc(const enum CBLAS_ORDER order, const int M, const int N, inline void cblas_gerc(const enum CBLAS_ORDER order, const int M, const int N,
const std::complex<float> *alpha, const std::complex<float> *X, const int incX, const std::complex<float>& alpha, const std::complex<float> *X, const int incX,
const std::complex<float> *Y, const int incY, std::complex<float> *A, const int lda) const std::complex<float> *Y, const int incY, std::complex<float> *A, const int lda)
{ {
cblas_cgerc (order, M, N, alpha, X, incX, Y, incY, A, lda); cblas_cgerc (order, M, N, &alpha, X, incX, Y, incY, A, lda);
} }
inline void cblas_gerc(const enum CBLAS_ORDER order, const int M, const int N, inline void cblas_gerc(const enum CBLAS_ORDER order, const int M, const int N,
const std::complex<double> *alpha, const std::complex<double> *X, const int incX, const std::complex<double>& alpha, const std::complex<double> *X, const int incX,
const std::complex<double> *Y, const int incY, std::complex<double> *A, const int lda) const std::complex<double> *Y, const int incY, std::complex<double> *A, const int lda)
{ {
cblas_zgerc (order, M, N, alpha, X, incX, Y, incY, A, lda); cblas_zgerc (order, M, N, &alpha, X, incX, Y, incY, A, lda);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -323,6 +324,7 @@ namespace dlib ...@@ -323,6 +324,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(m*m) DLIB_ADD_BLAS_BINDING(m*m)
{ {
//cout << "BLAS: m*m" << endl; //cout << "BLAS: m*m" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans; const CBLAS_TRANSPOSE TransB = CblasNoTrans;
...@@ -345,6 +347,7 @@ namespace dlib ...@@ -345,6 +347,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(trans(m)*m) DLIB_ADD_BLAS_BINDING(trans(m)*m)
{ {
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans; const CBLAS_TRANSPOSE TransB = CblasNoTrans;
...@@ -368,6 +371,7 @@ namespace dlib ...@@ -368,6 +371,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(m*trans(m)) DLIB_ADD_BLAS_BINDING(m*trans(m))
{ {
//cout << "BLAS: m*trans(m)" << endl; //cout << "BLAS: m*trans(m)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasTrans; const CBLAS_TRANSPOSE TransB = CblasTrans;
...@@ -390,6 +394,7 @@ namespace dlib ...@@ -390,6 +394,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(trans(m)*trans(m)) DLIB_ADD_BLAS_BINDING(trans(m)*trans(m))
{ {
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasTrans; const CBLAS_TRANSPOSE TransB = CblasTrans;
...@@ -417,6 +422,7 @@ namespace dlib ...@@ -417,6 +422,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(m*cv) DLIB_ADD_BLAS_BINDING(m*cv)
{ {
//cout << "BLAS: m*cv" << endl; //cout << "BLAS: m*cv" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const int M = static_cast<int>(src.lhs.nr()); const int M = static_cast<int>(src.lhs.nr());
...@@ -440,6 +446,7 @@ namespace dlib ...@@ -440,6 +446,7 @@ namespace dlib
// Note that rv*m is the same as trans(m)*trans(rv) // Note that rv*m is the same as trans(m)*trans(rv)
//cout << "BLAS: rv*m" << endl; //cout << "BLAS: rv*m" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
const int M = static_cast<int>(src.rhs.nr()); const int M = static_cast<int>(src.rhs.nr());
...@@ -463,6 +470,7 @@ namespace dlib ...@@ -463,6 +470,7 @@ namespace dlib
// Note that trans(cv)*m is the same as trans(m)*cv // Note that trans(cv)*m is the same as trans(m)*cv
//cout << "BLAS: trans(cv)*m" << endl; //cout << "BLAS: trans(cv)*m" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
const int M = static_cast<int>(src.rhs.nr()); const int M = static_cast<int>(src.rhs.nr());
...@@ -484,6 +492,7 @@ namespace dlib ...@@ -484,6 +492,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(m*trans(rv)) DLIB_ADD_BLAS_BINDING(m*trans(rv))
{ {
//cout << "BLAS: m*trans(rv)" << endl; //cout << "BLAS: m*trans(rv)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const int M = static_cast<int>(src.lhs.nr()); const int M = static_cast<int>(src.lhs.nr());
...@@ -507,6 +516,7 @@ namespace dlib ...@@ -507,6 +516,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(trans(m)*cv) DLIB_ADD_BLAS_BINDING(trans(m)*cv)
{ {
//cout << "BLAS: trans(m)*cv" << endl; //cout << "BLAS: trans(m)*cv" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
const int M = static_cast<int>(src.lhs.m.nr()); const int M = static_cast<int>(src.lhs.m.nr());
...@@ -530,6 +540,7 @@ namespace dlib ...@@ -530,6 +540,7 @@ namespace dlib
// Note that rv*trans(m) is the same as m*trans(rv) // Note that rv*trans(m) is the same as m*trans(rv)
//cout << "BLAS: rv*trans(m)" << endl; //cout << "BLAS: rv*trans(m)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const int M = static_cast<int>(src.rhs.m.nr()); const int M = static_cast<int>(src.rhs.m.nr());
...@@ -553,6 +564,7 @@ namespace dlib ...@@ -553,6 +564,7 @@ namespace dlib
// Note that trans(cv)*trans(m) is the same as m*cv // Note that trans(cv)*trans(m) is the same as m*cv
//cout << "BLAS: trans(cv)*trans(m)" << endl; //cout << "BLAS: trans(cv)*trans(m)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans; const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const int M = static_cast<int>(src.rhs.m.nr()); const int M = static_cast<int>(src.rhs.m.nr());
...@@ -574,6 +586,7 @@ namespace dlib ...@@ -574,6 +586,7 @@ namespace dlib
DLIB_ADD_BLAS_BINDING(trans(m)*trans(rv)) DLIB_ADD_BLAS_BINDING(trans(m)*trans(rv))
{ {
//cout << "BLAS: trans(m)*trans(rv)" << endl; //cout << "BLAS: trans(m)*trans(rv)" << endl;
const bool is_row_major_order = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor; const CBLAS_ORDER Order = is_row_major_order ? CblasRowMajor : CblasColMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans; const CBLAS_TRANSPOSE TransA = CblasTrans;
const int M = static_cast<int>(src.lhs.m.nr()); const int M = static_cast<int>(src.lhs.m.nr());
...@@ -598,6 +611,95 @@ namespace dlib ...@@ -598,6 +611,95 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
DLIB_ADD_BLAS_BINDING(cv*rv)
{
    // Binding for expressions matching the pattern cv*rv: both operands are
    // handed to cblas_ger unmodified and the product is accumulated into dest.
    //cout << "BLAS GER: cv*rv" << endl;
    const bool row_major = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
    const CBLAS_ORDER storage_order = row_major ? CblasRowMajor : CblasColMajor;

    const int num_rows = static_cast<int>(dest.nr());
    const int num_cols = static_cast<int>(dest.nc());

    const T* const lhs_data   = get_ptr(src.lhs);
    const int      lhs_stride = get_inc(src.lhs);
    const T* const rhs_data   = get_ptr(src.rhs);
    const int      rhs_stride = get_inc(src.rhs);

    // GER adds into A (standard BLAS semantics), so a plain assignment needs
    // dest cleared before the call.
    if (!add_to)
        zero_matrix(dest);

    T* const  out_data = get_ptr(dest);
    const int out_ld   = get_ld(dest);

    cblas_ger(storage_order, num_rows, num_cols, alpha,
              lhs_data, lhs_stride, rhs_data, rhs_stride, out_data, out_ld);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(rv)*rv)
{
    // Binding for expressions matching trans(rv)*rv.  The left operand is read
    // through src.lhs.m, i.e. the expression wrapped by the trans() node.
    //cout << "BLAS GER: trans(rv)*rv" << endl;
    const bool row_major = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
    const CBLAS_ORDER storage_order = row_major ? CblasRowMajor : CblasColMajor;

    const int num_rows = static_cast<int>(dest.nr());
    const int num_cols = static_cast<int>(dest.nc());

    const T* const lhs_data   = get_ptr(src.lhs.m);
    const int      lhs_stride = get_inc(src.lhs.m);
    const T* const rhs_data   = get_ptr(src.rhs);
    const int      rhs_stride = get_inc(src.rhs);

    // GER adds into A (standard BLAS semantics), so a plain assignment needs
    // dest cleared before the call.
    if (!add_to)
        zero_matrix(dest);

    T* const  out_data = get_ptr(dest);
    const int out_ld   = get_ld(dest);

    cblas_ger(storage_order, num_rows, num_cols, alpha,
              lhs_data, lhs_stride, rhs_data, rhs_stride, out_data, out_ld);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(cv*trans(cv))
{
    // Binding for expressions matching cv*trans(cv).  The right operand is read
    // through src.rhs.m, i.e. the expression wrapped by the trans() node.
    //cout << "BLAS GER: cv*trans(cv)" << endl;
    const bool row_major = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
    const CBLAS_ORDER storage_order = row_major ? CblasRowMajor : CblasColMajor;

    const int num_rows = static_cast<int>(dest.nr());
    const int num_cols = static_cast<int>(dest.nc());

    const T* const lhs_data   = get_ptr(src.lhs);
    const int      lhs_stride = get_inc(src.lhs);
    const T* const rhs_data   = get_ptr(src.rhs.m);
    const int      rhs_stride = get_inc(src.rhs.m);

    // GER adds into A (standard BLAS semantics), so a plain assignment needs
    // dest cleared before the call.
    if (!add_to)
        zero_matrix(dest);

    T* const  out_data = get_ptr(dest);
    const int out_ld   = get_ld(dest);

    cblas_ger(storage_order, num_rows, num_cols, alpha,
              lhs_data, lhs_stride, rhs_data, rhs_stride, out_data, out_ld);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(rv)*trans(cv))
{
    // Binding for expressions matching trans(rv)*trans(cv).  Both operands are
    // read through .m, i.e. the expressions wrapped by the trans() nodes.
    //cout << "BLAS GER: trans(rv)*trans(cv)" << endl;
    const bool row_major = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
    const CBLAS_ORDER storage_order = row_major ? CblasRowMajor : CblasColMajor;

    const int num_rows = static_cast<int>(dest.nr());
    const int num_cols = static_cast<int>(dest.nc());

    const T* const lhs_data   = get_ptr(src.lhs.m);
    const int      lhs_stride = get_inc(src.lhs.m);
    const T* const rhs_data   = get_ptr(src.rhs.m);
    const int      rhs_stride = get_inc(src.rhs.m);

    // GER adds into A (standard BLAS semantics), so a plain assignment needs
    // dest cleared before the call.
    if (!add_to)
        zero_matrix(dest);

    T* const  out_data = get_ptr(dest);
    const int out_ld   = get_ld(dest);

    cblas_ger(storage_order, num_rows, num_cols, alpha,
              lhs_data, lhs_stride, rhs_data, rhs_stride, out_data, out_ld);
} DLIB_END_BLAS_BINDING
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -605,6 +707,95 @@ namespace dlib ...@@ -605,6 +707,95 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
DLIB_ADD_BLAS_BINDING(cv*conj(rv))
{
    // Binding for expressions matching cv*conj(rv).  The operand wrapped by the
    // conj() node (src.rhs.m) is passed as Y; per the BLAS GERC definition the
    // routine itself applies the conjugation to Y.
    //cout << "BLAS GERC: cv*conj(rv)" << endl;
    const bool row_major = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
    const CBLAS_ORDER storage_order = row_major ? CblasRowMajor : CblasColMajor;

    const int num_rows = static_cast<int>(dest.nr());
    const int num_cols = static_cast<int>(dest.nc());

    const T* const lhs_data   = get_ptr(src.lhs);
    const int      lhs_stride = get_inc(src.lhs);
    const T* const rhs_data   = get_ptr(src.rhs.m);
    const int      rhs_stride = get_inc(src.rhs.m);

    // GERC adds into A (standard BLAS semantics), so a plain assignment needs
    // dest cleared before the call.
    if (!add_to)
        zero_matrix(dest);

    T* const  out_data = get_ptr(dest);
    const int out_ld   = get_ld(dest);

    cblas_gerc(storage_order, num_rows, num_cols, alpha,
               lhs_data, lhs_stride, rhs_data, rhs_stride, out_data, out_ld);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(cv*conj(trans(cv)))
{
    // Binding for expressions matching cv*conj(trans(cv)).  The operand wrapped
    // by the right-hand node (src.rhs.m) is passed as Y; per the BLAS GERC
    // definition the routine itself applies the conjugation to Y.
    //cout << "BLAS GERC: cv*conj(trans(cv))" << endl;
    const bool row_major = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
    const CBLAS_ORDER storage_order = row_major ? CblasRowMajor : CblasColMajor;

    const int num_rows = static_cast<int>(dest.nr());
    const int num_cols = static_cast<int>(dest.nc());

    const T* const lhs_data   = get_ptr(src.lhs);
    const int      lhs_stride = get_inc(src.lhs);
    const T* const rhs_data   = get_ptr(src.rhs.m);
    const int      rhs_stride = get_inc(src.rhs.m);

    // GERC adds into A (standard BLAS semantics), so a plain assignment needs
    // dest cleared before the call.
    if (!add_to)
        zero_matrix(dest);

    T* const  out_data = get_ptr(dest);
    const int out_ld   = get_ld(dest);

    cblas_gerc(storage_order, num_rows, num_cols, alpha,
               lhs_data, lhs_stride, rhs_data, rhs_stride, out_data, out_ld);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(rv)*conj(trans(cv)))
{
    // Binding for expressions matching trans(rv)*conj(trans(cv)).  Both operands
    // are read through .m (the expressions wrapped by the outer nodes); per the
    // BLAS GERC definition the routine itself applies the conjugation to Y.
    //cout << "BLAS GERC: trans(rv)*conj(trans(cv))" << endl;
    const bool row_major = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
    const CBLAS_ORDER storage_order = row_major ? CblasRowMajor : CblasColMajor;

    const int num_rows = static_cast<int>(dest.nr());
    const int num_cols = static_cast<int>(dest.nc());

    const T* const lhs_data   = get_ptr(src.lhs.m);
    const int      lhs_stride = get_inc(src.lhs.m);
    const T* const rhs_data   = get_ptr(src.rhs.m);
    const int      rhs_stride = get_inc(src.rhs.m);

    // GERC adds into A (standard BLAS semantics), so a plain assignment needs
    // dest cleared before the call.
    if (!add_to)
        zero_matrix(dest);

    T* const  out_data = get_ptr(dest);
    const int out_ld   = get_ld(dest);

    cblas_gerc(storage_order, num_rows, num_cols, alpha,
               lhs_data, lhs_stride, rhs_data, rhs_stride, out_data, out_ld);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(rv)*conj(rv))
{
    // Binding for expressions matching trans(rv)*conj(rv).  Both operands are
    // read through .m (the expressions wrapped by the trans()/conj() nodes); per
    // the BLAS GERC definition the routine itself applies the conjugation to Y.
    //cout << "BLAS GERC: trans(rv)*conj(rv)" << endl;
    const bool row_major = is_same_type<typename dest_exp::layout_type,row_major_layout>::value;
    const CBLAS_ORDER storage_order = row_major ? CblasRowMajor : CblasColMajor;

    const int num_rows = static_cast<int>(dest.nr());
    const int num_cols = static_cast<int>(dest.nc());

    const T* const lhs_data   = get_ptr(src.lhs.m);
    const int      lhs_stride = get_inc(src.lhs.m);
    const T* const rhs_data   = get_ptr(src.rhs.m);
    const int      rhs_stride = get_inc(src.rhs.m);

    // GERC adds into A (standard BLAS semantics), so a plain assignment needs
    // dest cleared before the call.
    if (!add_to)
        zero_matrix(dest);

    T* const  out_data = get_ptr(dest);
    const int out_ld   = get_ld(dest);

    cblas_gerc(storage_order, num_rows, num_cols, alpha,
               lhs_data, lhs_stride, rhs_data, rhs_stride, out_data, out_ld);
} DLIB_END_BLAS_BINDING
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -612,6 +803,75 @@ namespace dlib ...@@ -612,6 +803,75 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
DLIB_ADD_BLAS_BINDING(rv*cv)
{
    // Binding for expressions matching rv*cv: the result of cblas_dot, scaled
    // by alpha, is written to (or added to) the single element dest(0).
    //cout << "BLAS DOT: rv*cv" << endl;
    const int length = static_cast<int>(src.lhs.size());

    const T* const lhs_data   = get_ptr(src.lhs);
    const int      lhs_stride = get_inc(src.lhs);
    const T* const rhs_data   = get_ptr(src.rhs);
    const int      rhs_stride = get_inc(src.rhs);

    if (add_to)
        dest(0) += alpha*cblas_dot(length, lhs_data, lhs_stride, rhs_data, rhs_stride);
    else
        dest(0) = alpha*cblas_dot(length, lhs_data, lhs_stride, rhs_data, rhs_stride);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(cv)*cv)
{
    // Binding for expressions matching trans(cv)*cv.  The left operand's data
    // is read through src.lhs.m, the expression wrapped by the trans() node;
    // the element count is taken from src.lhs itself.
    //cout << "BLAS DOT: trans(cv)*cv" << endl;
    const int length = static_cast<int>(src.lhs.size());

    const T* const lhs_data   = get_ptr(src.lhs.m);
    const int      lhs_stride = get_inc(src.lhs.m);
    const T* const rhs_data   = get_ptr(src.rhs);
    const int      rhs_stride = get_inc(src.rhs);

    if (add_to)
        dest(0) += alpha*cblas_dot(length, lhs_data, lhs_stride, rhs_data, rhs_stride);
    else
        dest(0) = alpha*cblas_dot(length, lhs_data, lhs_stride, rhs_data, rhs_stride);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(rv*trans(rv))
{
    // Binding for expressions matching rv*trans(rv).  The right operand's data
    // is read through src.rhs.m, the expression wrapped by the trans() node.
    //cout << "BLAS DOT: rv*trans(rv)" << endl;
    const int length = static_cast<int>(src.lhs.size());

    const T* const lhs_data   = get_ptr(src.lhs);
    const int      lhs_stride = get_inc(src.lhs);
    const T* const rhs_data   = get_ptr(src.rhs.m);
    const int      rhs_stride = get_inc(src.rhs.m);

    if (add_to)
        dest(0) += alpha*cblas_dot(length, lhs_data, lhs_stride, rhs_data, rhs_stride);
    else
        dest(0) = alpha*cblas_dot(length, lhs_data, lhs_stride, rhs_data, rhs_stride);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(cv)*trans(rv))
{
    // Binding for expressions matching trans(cv)*trans(rv).  Both operands are
    // read through .m, the expressions wrapped by the trans() nodes.
    //cout << "BLAS DOT: trans(cv)*trans(rv)" << endl;
    const int length = static_cast<int>(src.lhs.m.size());

    const T* const lhs_data   = get_ptr(src.lhs.m);
    const int      lhs_stride = get_inc(src.lhs.m);
    const T* const rhs_data   = get_ptr(src.rhs.m);
    const int      rhs_stride = get_inc(src.rhs.m);

    if (add_to)
        dest(0) += alpha*cblas_dot(length, lhs_data, lhs_stride, rhs_data, rhs_stride);
    else
        dest(0) = alpha*cblas_dot(length, lhs_data, lhs_stride, rhs_data, rhs_stride);
} DLIB_END_BLAS_BINDING
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -619,9 +879,57 @@ namespace dlib ...@@ -619,9 +879,57 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
DLIB_ADD_BLAS_BINDING(conj(rv)*cv)
{
//cout << "BLAS DOTC: conj(rv)*cv" << endl;
const int N = static_cast<int>(src.lhs.m.size());
const T* X = get_ptr(src.lhs.m);
const int incX = get_inc(src.lhs.m);
const T* Y = get_ptr(src.rhs);
const int incY = get_inc(src.rhs);
if (add_to == false)
dest(0) = alpha*cblas_dotc(N, X, incX, Y, incY);
else
dest(0) += alpha*cblas_dotc(N, X, incX, Y, incY);
// ---------------------------------------------------------------------------------------- } DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(conj(trans(cv))*cv)
{
    // Binding for expressions matching conj(trans(cv))*cv.  The operand wrapped
    // by the left-hand node (src.lhs.m) is passed as X; per the BLAS DOTC
    // definition the routine itself applies the conjugation to X.
    //cout << "BLAS DOTC: conj(trans(cv))*cv" << endl;
    const int length = static_cast<int>(src.lhs.m.size());

    const T* const lhs_data   = get_ptr(src.lhs.m);
    const int      lhs_stride = get_inc(src.lhs.m);
    const T* const rhs_data   = get_ptr(src.rhs);
    const int      rhs_stride = get_inc(src.rhs);

    if (add_to)
        dest(0) += alpha*cblas_dotc(length, lhs_data, lhs_stride, rhs_data, rhs_stride);
    else
        dest(0) = alpha*cblas_dotc(length, lhs_data, lhs_stride, rhs_data, rhs_stride);
} DLIB_END_BLAS_BINDING
// --------------------------------------
DLIB_ADD_BLAS_BINDING(trans(conj(cv))*trans(rv))
{
    // Binding for expressions matching trans(conj(cv))*trans(rv).  Both operands
    // are read through .m (the expressions wrapped by the outer nodes); per the
    // BLAS DOTC definition the routine itself applies the conjugation to X.
    //cout << "BLAS DOTC: trans(conj(cv))*trans(rv)" << endl;
    const int length = static_cast<int>(src.lhs.m.size());

    const T* const lhs_data   = get_ptr(src.lhs.m);
    const int      lhs_stride = get_inc(src.lhs.m);
    const T* const rhs_data   = get_ptr(src.rhs.m);
    const int      rhs_stride = get_inc(src.rhs.m);

    if (add_to)
        dest(0) += alpha*cblas_dotc(length, lhs_data, lhs_stride, rhs_data, rhs_stride);
    else
        dest(0) = alpha*cblas_dotc(length, lhs_data, lhs_stride, rhs_data, rhs_stride);
} DLIB_END_BLAS_BINDING
} }
......
// Copyright (C) 2009 Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_MATRIx_CONJ_TRANS_FUNCTIONS
#define DLIB_MATRIx_CONJ_TRANS_FUNCTIONS
#include "matrix_utilities.h"
#include "matrix_math_functions.h"
#include "matrix.h"
#include "../algs.h"
#include <cmath>
#include <complex>
#include <limits>
namespace dlib
{
/*!
The point of the two functions defined in this file is to make statements
of the form conj(trans(m)) and trans(conj(m)) look the same so that it is
easier to map them to BLAS functions later on.
!*/
// ----------------------------------------------------------------------------------------
struct op_conj_trans
{
    /*!
        Operation tag for matrix_unary_exp representing the conjugate
        transpose of an expression EXP: element (r,c) of the result is the
        complex conjugate of element (c,r) of the wrapped expression.
    !*/
    template <typename EXP>
    struct op : has_destructive_aliasing
    {
        typedef typename EXP::type type;
        typedef typename EXP::mem_manager_type mem_manager_type;

        // Compile-time dimensions are swapped relative to EXP; the cost
        // estimate of EXP is passed through unchanged.
        static const long NR = EXP::NC;
        static const long NC = EXP::NR;
        static const long cost = EXP::cost;

        // Reads the mirrored element of m and conjugates it.
        template <typename MAT>
        static type apply (const MAT& m, long r, long c)
        {
            return std::conj(m(c,r));
        }

        // Runtime row count is the column count of the wrapped expression.
        template <typename MAT>
        static long nr (const MAT& m)
        {
            return m.nc();
        }

        // Runtime column count is the row count of the wrapped expression.
        template <typename MAT>
        static long nc (const MAT& m)
        {
            return m.nr();
        }
    };
};
/*!
    Overload of trans() for an argument that is already a conj() expression.
    It rewraps the expression held inside the conj() node (m.m) in a single
    op_conj_trans node.
!*/
template <typename EXP>
const matrix_unary_exp<EXP,op_conj_trans> trans (
    const matrix_unary_exp<EXP,op_conj>& m
)
{
    return matrix_unary_exp<EXP,op_conj_trans>(m.m);
}
/*!
    Overload of conj() for an argument that is already a trans() expression.
    It rewraps the expression held inside the trans() node (m.m) in a single
    op_conj_trans node, the same node type the matching trans() overload yields.
!*/
template <typename EXP>
const matrix_unary_exp<EXP,op_conj_trans> conj (
    const matrix_unary_exp<EXP,op_trans>& m
)
{
    return matrix_unary_exp<EXP,op_conj_trans>(m.m);
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_MATRIx_CONJ_TRANS_FUNCTIONS
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment