Commit 0920c8f8 authored by Davis King

Reorganized the matrix_assign code a little and also added in
some stuff to allow me to bind whatever expressions I feel like
to optimized BLAS libraries.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402752
parent 7baab9fa
......@@ -8,6 +8,7 @@
#include "matrix/matrix_subexp.h"
#include "matrix/matrix_math_functions.h"
#include "matrix/matrix_assign.h"
#include "matrix/matrix_blas_bindings.h"
#endif // DLIB_MATRIx_HEADER
......
This diff is collapsed.
......@@ -7,11 +7,22 @@
namespace dlib
{
/*
The point of the matrix_assign() functions is to contain all the various
optimizations that help the matrix assign a matrix_exp to an actual matrix
object quickly.
*/
// ----------------------------------------------------------------------------------------
namespace ma
{
// This template here controls how big a compile time sized matrix needs
// to be for it to get passed into the optimized versions of the
// matrix_assign() function. So small matrices are evaluated with a simple
// loop like the ones in this file and bigger matrices may get sent to BLAS
// routines or some other kind of optimized thing.
// Primary template of the size trait described above.  Unless a
// specialization says otherwise, an expression type is not classified
// as a small matrix, so value is false here.
template <typename EXP, typename enable = void>
struct is_small_matrix
{
    static const bool value = false;
};
template < typename EXP >
......
// Copyright (C) 2008 Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_MATRIx_BLAS_BINDINGS_
#define DLIB_MATRIx_BLAS_BINDINGS_
#include "matrix_assign.h"
#ifdef DLIB_FOUND_BLAS
#include "mkl_cblas.h"
#endif
namespace dlib
{
namespace blas_bindings
{
// ----------------------------------------------------------------------------------------
// Memory manager type shared by the placeholder matrix declarations below.
typedef memory_manager<char>::kernel_1a mm;
// Declaration-only (extern) matrix objects.  dm and sm are used further down to
// spell the expression patterns handed to DLIB_ADD_BLAS_BINDING (dm*dm,
// trans(dm)*dm, sm*sm, trans(sm)*sm).
// NOTE(review): no definitions are visible in this file; presumably only the
// types of these objects matter and they are never odr-used -- confirm.
// NOTE(review): presumably a 0 dimension means "sized at run time" -- confirm.
extern matrix<double,0,0,mm,row_major_layout> dm;
extern matrix<float,0,0,mm,row_major_layout> sm;
// Vector shaped placeholders (one dimension fixed to 1) for double and float.
// None of the bindings visible in this file reference them yet.
extern matrix<double,1,0,mm,row_major_layout> drv;
extern matrix<double,0,1,mm,row_major_layout> dcv;
extern matrix<float,1,0,mm,row_major_layout> srv;
extern matrix<float,0,1,mm,row_major_layout> scv;
// NOTE(review): this using-directive sits inside namespace dlib::blas_bindings in
// a header, so it affects name lookup for every file that includes this header.
// Nothing visible here needs it; consider removing it after checking what the
// DLIB_ADD_BLAS_BINDING macro expands to.
using namespace std;
#ifdef DLIB_FOUND_BLAS
DLIB_ADD_BLAS_BINDING(double, row_major_layout, dm*dm)
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const double alpha = 1;
const double* A = &src.lhs(0,0);
const int lda = src.lhs.nc();
const double* B = &src.rhs(0,0);
const int ldb = src.rhs.nc();
const double beta = 0;
double* C = &dest(0,0);
const int ldc = src.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}};
DLIB_ADD_BLAS_BINDING(double, row_major_layout, trans(dm)*dm)
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const double alpha = 1;
const double* A = &src.lhs.m(0,0);
const int lda = src.lhs.m.nc();
const double* B = &src.rhs(0,0);
const int ldb = src.rhs.nc();
const double beta = 0;
double* C = &dest(0,0);
const int ldc = src.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}};
// float overloads
DLIB_ADD_BLAS_BINDING(float, row_major_layout, sm*sm)
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const float alpha = 1;
const float* A = &src.lhs(0,0);
const int lda = src.lhs.nc();
const float* B = &src.rhs(0,0);
const int ldb = src.rhs.nc();
const float beta = 0;
float* C = &dest(0,0);
const int ldc = src.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}};
DLIB_ADD_BLAS_BINDING(float, row_major_layout, trans(sm)*sm)
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const float alpha = 1;
const float* A = &src.lhs.m(0,0);
const int lda = src.lhs.m.nc();
const float* B = &src.rhs(0,0);
const int ldb = src.rhs.nc();
const float beta = 0;
float* C = &dest(0,0);
const int ldc = src.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}};
#endif // DLIB_FOUND_BLAS
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_MATRIx_BLAS_BINDINGS_
// Copyright (C) 2008 Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_MATRIx_DEFAULT_MULTIPLY_
#define DLIB_MATRIx_DEFAULT_MULTIPLY_
#include "../geometry.h"
#include "matrix.h"
#include "matrix_utilities.h"
#include "../enable_if.h"
namespace dlib
{
// ------------------------------------------------------------------------------------
namespace ma
{
    // Compile time trait: is the expression type EXP shaped like a vector?
    // Primary template -- used whenever the specialization below is not
    // selected, and reports false.
    template <typename EXP, typename enable = void>
    struct matrix_is_vector
    {
        static const bool value = false;
    };

    // Specialization picked when EXP::NR or EXP::NC is equal to 1;
    // it reports true.
    template <typename EXP>
    struct matrix_is_vector<EXP, typename enable_if_c<EXP::NR==1 || EXP::NC==1>::type >
    {
        static const bool value = true;
    };
}
// ------------------------------------------------------------------------------------
// Forward declaration (with its contract) of the default_matrix_multiply()
// overload that handles the case where at least one operand is a vector.
//
// The enabling condition must use || so that this is a redeclaration of the
// overload defined below.  With && it would instead introduce a separate,
// never-defined overload that is also viable when both operands are vectors,
// which makes such calls ambiguous.
template <
    typename matrix_dest_type,
    typename EXP1,
    typename EXP2
    >
typename enable_if_c<ma::matrix_is_vector<EXP1>::value == true || ma::matrix_is_vector<EXP2>::value == true>::type
default_matrix_multiply (
    matrix_dest_type& dest,
    const EXP1& lhs,
    const EXP2& rhs
);
/*!
    requires
        - (lhs*rhs).destructively_aliases(dest) == false
        - dest.nr() == (lhs*rhs).nr()
        - dest.nc() == (lhs*rhs).nc()
    ensures
        - #dest == dest + lhs*rhs
!*/
// ------------------------------------------------------------------------------------
template <
    typename matrix_dest_type,
    typename EXP1,
    typename EXP2
    >
typename enable_if_c<ma::matrix_is_vector<EXP1>::value == true || ma::matrix_is_vector<EXP2>::value == true>::type
default_matrix_multiply (
    matrix_dest_type& dest,
    const EXP1& lhs,
    const EXP2& rhs
)
{
    // Overload selected when at least one of the operands is a vector.
    // Adds lhs*rhs into dest (dest is accumulated into, never cleared here).
    //
    // The innermost loop walks along a row of rhs and a row of dest, which
    // is the friendly access pattern when data is stored in row major order.
    const long num_rows  = lhs.nr();
    const long inner_dim = lhs.nc();
    const long num_cols  = rhs.nc();

    for (long row = 0; row < num_rows; ++row)
    {
        for (long k = 0; k < inner_dim; ++k)
        {
            // lhs(row,k) is reused for the entire inner loop, so read it once.
            const typename EXP2::type scale = lhs(row,k);
            for (long col = 0; col < num_cols; ++col)
            {
                dest(row,col) += rhs(k,col)*scale;
            }
        }
    }
}
// ------------------------------------------------------------------------------------
template <
    typename matrix_dest_type,
    typename EXP1,
    typename EXP2
    >
typename enable_if_c<ma::matrix_is_vector<EXP1>::value == false && ma::matrix_is_vector<EXP2>::value == false>::type
default_matrix_multiply (
    matrix_dest_type& dest,
    const EXP1& lhs,
    const EXP2& rhs
)
{
    // Overload selected when neither operand is a vector.
    // Adds lhs*rhs into dest (dest is accumulated into, never cleared here).

    // Edge length of the square blocks used by the blocked algorithm.  It is
    // also used (times 10) as the element count threshold below.
    const long bs = 90;

    // if the matrices are small enough then just use the simple multiply algorithm
    if (lhs.nc() <= 2 || rhs.nc() <= 2 || lhs.nr() <= 2 || rhs.nr() <= 2 || (lhs.size() <= bs*10 && rhs.size() <= bs*10) )
    {
        // This loop is optimized assuming that the data is laid out in
        // row major order in memory.
        for (long r = 0; r < lhs.nr(); ++r)
        {
            for (long c = 0; c < lhs.nc(); ++c)
            {
                const typename EXP2::type temp = lhs(r,c);
                for (long i = 0; i < rhs.nc(); ++i)
                {
                    dest(r,i) += rhs(c,i)*temp;
                }
            }
        }
    }
    else
    {
        // if the lhs and rhs matrices are big enough we should use a cache friendly
        // algorithm that computes the matrix multiply in blocks.

        // Loop over all the blocks in the lhs matrix.  Each block is described
        // by its first index and its last (inclusive) index, clamped to the
        // matrix bounds so partial blocks at the edges are handled.
        for (long r = 0; r < lhs.nr(); r += bs)
        {
            const long r_last = std::min(r+bs-1, lhs.nr()-1);

            for (long c = 0; c < lhs.nc(); c += bs)
            {
                const long c_last = std::min(c+bs-1, lhs.nc()-1);

                // now loop over all the rhs blocks we have to multiply with the current lhs block
                for (long i = 0; i < rhs.nc(); i += bs)
                {
                    const long i_last = std::min(i+bs-1, rhs.nc()-1);

                    // Multiply the current lhs block with the current rhs block
                    // and accumulate into the matching block of dest.  The
                    // element level counters get their own names so they do not
                    // shadow the block level counters r, c and i.
                    //
                    // This loop is optimized assuming that the data is laid out in
                    // row major order in memory.
                    for (long rr = r; rr <= r_last; ++rr)
                    {
                        for (long cc = c; cc <= c_last; ++cc)
                        {
                            const typename EXP2::type temp = lhs(rr,cc);
                            for (long ii = i; ii <= i_last; ++ii)
                            {
                                dest(rr,ii) += rhs(cc,ii)*temp;
                            }
                        }
                    }
                }
            }
        }
    }
}
// ------------------------------------------------------------------------------------
}
#endif // DLIB_MATRIx_DEFAULT_MULTIPLY_
......@@ -285,7 +285,6 @@ namespace dlib
long nc (
) const { return OP::nc(m); }
private:
const M& m;
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment