Commit 0920c8f8 authored by Davis King's avatar Davis King

Reorganized the matrix_assign code a little and also added in

some stuff to allow me to bind whatever expressions I feel like
to optimized BLAS libraries.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402752
parent 7baab9fa
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "matrix/matrix_subexp.h" #include "matrix/matrix_subexp.h"
#include "matrix/matrix_math_functions.h" #include "matrix/matrix_math_functions.h"
#include "matrix/matrix_assign.h" #include "matrix/matrix_assign.h"
#include "matrix/matrix_blas_bindings.h"
#endif // DLIB_MATRIx_HEADER #endif // DLIB_MATRIx_HEADER
......
...@@ -8,147 +8,169 @@ ...@@ -8,147 +8,169 @@
#include "matrix_utilities.h" #include "matrix_utilities.h"
#include "../enable_if.h" #include "../enable_if.h"
#include "matrix_assign_fwd.h" #include "matrix_assign_fwd.h"
#include "matrix_default_mul.h"
namespace dlib namespace dlib
{ {
/* /*
This file contains some templates that are used inside the matrix_blas_bindings.h
This file is where all the implementations of matrix_assign() live. The point of the file to bind various matrix expressions to optimized code for carrying them out.
matrix_assign() functions is to contain all the various optimizations that help the
matrix assign a matrix_exp to an actual matrix object quickly.
*/ */
namespace ma // ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
namespace blas_bindings
{ {
// This namespace defines whatever helpers we need in the rest of this file.
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
struct op_null // This template struct is used to tell us if two matrix types both contain the same type of
{ // element, have the same layout, and are both general matrices or both vectors.
template <typename EXP> template <typename T, typename U>
struct op : has_nondestructive_aliasing, preserves_dimensions<EXP> struct same_matrix
{
const static long cost = EXP::cost;
typedef typename EXP::type type;
template <typename M>
static type apply ( const M& m, long r, long c)
{ {
return m(r,c); const static bool value = false;
}
};
}; };
template < template <typename T, typename L, long a, long b, long c, long d, typename MM1, typename MM2 >
typename EXP struct same_matrix <matrix<T,a,b,MM1,L>, matrix<T,c,d,MM2,L> >
>
const matrix_unary_exp<EXP,op_null> null_exp (
const matrix_exp<EXP>& m
)
/*!
The only point of this function is to make it easy to cause the overloads
of matrix_assign to not trigger for a matrix expression.
!*/
{ {
return matrix_unary_exp<EXP,op_null>(m.ref()); /*! These two matrices are the same if they are either:
} - both row vectors
- both column vectors
- both not any kind of vector
!*/
const static bool value = (a == 1 && c == 1) || (b==1 && d==1) || (a!=1 && b!=1 && c!=1 && d!=1) ;
};
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
template < typename EXP, typename enable = void > // This template struct is used to tell us if two matrix expressions both contain the same
struct matrix_is_vector { static const bool value = false; }; // sequence of operators, expressions, and work on matrices laid out in memory in compatible ways.
template < typename EXP > template <typename T, typename U>
struct matrix_is_vector<EXP, typename enable_if_c<EXP::NR==1 || EXP::NC==1>::type > { static const bool value = true; }; struct same_exp
{
const static bool value = is_same_type<typename T::exp_type, typename U::exp_type>::value ||
same_matrix<typename T::exp_type, typename U::exp_type>::value;;
};
} template <typename Tlhs, typename Ulhs, typename Trhs, typename Urhs>
struct same_exp<matrix_multiply_exp<Tlhs,Trhs>, matrix_multiply_exp<Ulhs,Urhs> >
{ const static bool value = same_exp<Tlhs,Ulhs>::value && same_exp<Trhs,Urhs>::value; };
// ---------------------------------------------------------------------------------------- template <typename Tlhs, typename Ulhs, typename Trhs, typename Urhs>
struct same_exp<matrix_add_exp<Tlhs,Trhs>, matrix_add_exp<Ulhs,Urhs> >
{ const static bool value = same_exp<Tlhs,Ulhs>::value && same_exp<Trhs,Urhs>::value; };
template < template <typename Tlhs, typename Ulhs, typename Trhs, typename Urhs>
typename matrix_dest_type, struct same_exp<matrix_subtract_exp<Tlhs,Trhs>, matrix_subtract_exp<Ulhs,Urhs> >
typename EXP1, { const static bool value = same_exp<Tlhs,Ulhs>::value && same_exp<Trhs,Urhs>::value; };
typename EXP2
>
typename disable_if_c<ma::matrix_is_vector<EXP1>::value || ma::matrix_is_vector<EXP2>::value>::type
matrix_assign_big (
matrix_dest_type& dest,
const matrix_multiply_exp<EXP1,EXP2>& src
)
/*!
This overload catches assignments like:
dest = lhs*rhs
where lhs and rhs are both not vectors
!*/
{
using namespace ma;
const EXP1& lhs = src.lhs;
const EXP2& rhs = src.rhs;
const long bs = 90;
set_all_elements(dest,0);
// if the matrices are small enough then just use the simple multiply algorithm template <typename T, typename U> struct same_exp<matrix_mul_scal_exp<T>, matrix_mul_scal_exp<U> >
if (lhs.nc() <= 2 || rhs.nc() <= 2 || lhs.nr() <= 2 || rhs.nr() <= 2 || (lhs.size() <= bs*10 && rhs.size() <= bs*10) ) { const static bool value = same_exp<T,U>::value; };
{
// This loop is optimized assuming that the data is laid out in
// row major order in memory.
for (long r = 0; r< lhs.nr(); ++r)
{
for (long c = 0; c< lhs.nc(); ++c)
{
const typename EXP2::type temp = lhs(r,c);
for (long i = 0; i < rhs.nc(); ++i)
{
dest(r,i) += rhs(c,i)*temp;
}
}
}
}
else
{
// if the lhs and rhs matrices are big enough we should use a cache friendly
// algorithm that computes the matrix multiply in blocks.
template <typename T, typename U> struct same_exp<matrix_div_scal_exp<T>, matrix_div_scal_exp<U> >
{ const static bool value = same_exp<T,U>::value; };
// Loop over all the blocks in the lhs matrix template <typename T, typename U, typename OP> struct same_exp<matrix_unary_exp<T,OP>, matrix_unary_exp<U,OP> >
for (long r = 0; r < lhs.nr(); r+=bs) { const static bool value = same_exp<T,U>::value; };
// ------------------------------------------------------------------------------------
struct yes_type
{ {
for (long c = 0; c < lhs.nc(); c+=bs) char ch;
};
struct no_type
{ {
// make a rect for the block from lhs yes_type a, b;
rectangle lhs_block(c, r, std::min(c+bs-1,lhs.nc()-1), std::min(r+bs-1,lhs.nr()-1)); };
// now loop over all the rhs blocks we have to multiply with the current lhs block // This is a helper that is used below to apply the same_exp template to matrix expressions.
for (long i = 0; i < rhs.nc(); i += bs) template <typename T, typename U>
{ typename enable_if<same_exp<T,U>,yes_type>::type test(U);
// make a rect for the block from rhs template <typename T, typename U>
rectangle rhs_block(i, c, std::min(i+bs-1,rhs.nc()-1), std::min(c+bs-1,rhs.nr()-1)); typename disable_if<same_exp<T,U>,no_type>::type test(U);
// make a target rect in res // ------------------------------------------------------------------------------------
rectangle res_block(rhs_block.left(),lhs_block.top(), rhs_block.right(), lhs_block.bottom());
// This loop is optimized assuming that the data is laid out in template <
// row major order in memory. typename T, long NR, long NC, typename MM, typename L,
for (long r = lhs_block.top(); r <= lhs_block.bottom(); ++r) typename src_exp,
typename enabled = void
>
struct matrix_assign_blas_helper
{ {
for (long c = lhs_block.left(); c<= lhs_block.right(); ++c) // We are in the default version of the blas helper so this
// means there wasn't any more specific overload. So just
// let the default matrix assignment happen.
template <typename EXP>
static void assign (
matrix<T,NR,NC,MM,L>& dest,
const EXP& src
)
{ {
const typename EXP2::type temp = lhs(r,c); for (long r = 0; r < src.nr(); ++r)
for (long i = rhs_block.left(); i <= rhs_block.right(); ++i)
{ {
dest(r,i) += rhs(c,i)*temp; for (long c = 0; c < src.nc(); ++c)
} {
} dest(r,c) = src(r,c);
}
} }
} }
} }
// If we know this is a matrix multiply then apply the
// default dlib matrix multiply to speed things up a bit more
// than the above default function would.
template <typename EXP1, typename EXP2>
static void assign (
matrix<T,NR,NC,MM,L>& dest,
const matrix_multiply_exp<EXP1,EXP2>& src
)
{
set_all_elements(dest,0);
default_matrix_multiply(dest, src.lhs, src.rhs);
} }
};
// DLIB_ADD_BLAS_BINDING(dest_type, dest_layout, src_expression)
//
// Opens a partial specialization of matrix_assign_blas_helper for destination
// matrices whose element type is dest_type and whose layout is dest_layout.
//
// The macro expands to two things:
//   1. A trait struct named blas<line number> (built with BOOST_JOIN and __LINE__)
//      whose value is true when sizeof(test<T>(src_expression)) equals
//      sizeof(yes_type).  src_expression only appears inside that sizeof().
//   2. The head of the matrix_assign_blas_helper specialization, enabled through
//      enable_if on that trait, followed by the opening of its static
//      assign(dest, src) member.
//
// The expansion deliberately ends with two unclosed braces: the code written
// right after the macro invocation becomes the body of assign(), and the user
// must finish it with "}};" to close both the function and the struct.
//
// Because the trait name is derived from __LINE__, each invocation should sit
// on its own source line.
//
// No comments may be placed on the continued lines below since a // comment
// would swallow the trailing backslash.
#define DLIB_ADD_BLAS_BINDING( dest_type, dest_layout, src_expression) \
template <typename T> struct BOOST_JOIN(blas,__LINE__) \
{ const static bool value = sizeof(yes_type) == sizeof(test<T>(src_expression)); }; \
template < long NR, long NC, typename MM, typename src_exp > \
struct matrix_assign_blas_helper<dest_type,NR,NC,MM,dest_layout, src_exp, \
typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp> >::type > { \
static void assign ( \
matrix<dest_type,NR,NC,MM,dest_layout>& dest, \
const src_exp& src \
) {
// ------------------------------------------------------------------------------------
} // end of namespace blas_bindings
// ------------------------------------------------------------------------------------
template <
typename T, long NR, long NC, typename MM, typename L,
typename src_exp
>
inline void matrix_assign_big (
matrix<T,NR,NC,MM,L>& dest,
const src_exp& src
)
{
blas_bindings::matrix_assign_blas_helper<T,NR,NC,MM,L,src_exp>::assign(dest,src);
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
} }
......
...@@ -8,10 +8,21 @@ ...@@ -8,10 +8,21 @@
namespace dlib namespace dlib
{ {
/*
The point of the matrix_assign() functions is to contain all the various
optimizations that help the matrix assign a matrix_exp to an actual matrix
object quickly.
*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
namespace ma namespace ma
{ {
// This template here controls how big a compile time sized matrix needs
// to be for it to get passed into the optimized versions of the
// matrix_assign() function. So small matrices are evaluated with a simple
// loop like the ones in this file and bigger matrices may get sent to BLAS
// routines or some other kind of optimized thing.
template < typename EXP, typename enable = void > template < typename EXP, typename enable = void >
struct is_small_matrix { static const bool value = false; }; struct is_small_matrix { static const bool value = false; };
template < typename EXP > template < typename EXP >
......
// Copyright (C) 2008 Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_MATRIx_BLAS_BINDINGS_
#define DLIB_MATRIx_BLAS_BINDINGS_
#include "matrix_assign.h"
#ifdef DLIB_FOUND_BLAS
#include "mkl_cblas.h"
#endif
namespace dlib
{
namespace blas_bindings
{
// ----------------------------------------------------------------------------------------
// Memory manager type used by the placeholder matrices declared below.
typedef memory_manager<char>::kernel_1a mm;
// Placeholder objects used to spell out the expression patterns handed to
// DLIB_ADD_BLAS_BINDING (e.g. dm*dm or trans(sm)*sm).  They are only declared
// here; no definition is visible in this header.
// NOTE(review): the macro only names them inside a sizeof() expression, so
// presumably they are never meant to be defined -- confirm.
//
// Naming: leading d = double, s = float; m = dynamically sized matrix,
// rv = row vector (1 row), cv = column vector (1 column).
extern matrix<double,0,0,mm,row_major_layout> dm;
extern matrix<float,0,0,mm,row_major_layout> sm;
// The four vector placeholders are not referenced by any binding in this file yet.
extern matrix<double,1,0,mm,row_major_layout> drv;
extern matrix<double,0,1,mm,row_major_layout> dcv;
extern matrix<float,1,0,mm,row_major_layout> srv;
extern matrix<float,0,1,mm,row_major_layout> scv;
// NOTE(review): this using-directive sits in a header and makes all of std
// visible inside dlib::blas_bindings.  None of the bindings below appear to
// need it -- consider removing after checking other users of this namespace.
using namespace std;
#ifdef DLIB_FOUND_BLAS
DLIB_ADD_BLAS_BINDING(double, row_major_layout, dm*dm)
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const double alpha = 1;
const double* A = &src.lhs(0,0);
const int lda = src.lhs.nc();
const double* B = &src.rhs(0,0);
const int ldb = src.rhs.nc();
const double beta = 0;
double* C = &dest(0,0);
const int ldc = src.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}};
DLIB_ADD_BLAS_BINDING(double, row_major_layout, trans(dm)*dm)
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const double alpha = 1;
const double* A = &src.lhs.m(0,0);
const int lda = src.lhs.m.nc();
const double* B = &src.rhs(0,0);
const int ldb = src.rhs.nc();
const double beta = 0;
double* C = &dest(0,0);
const int ldc = src.nc();
cblas_dgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}};
// float overloads
DLIB_ADD_BLAS_BINDING(float, row_major_layout, sm*sm)
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasNoTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const float alpha = 1;
const float* A = &src.lhs(0,0);
const int lda = src.lhs.nc();
const float* B = &src.rhs(0,0);
const int ldb = src.rhs.nc();
const float beta = 0;
float* C = &dest(0,0);
const int ldc = src.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}};
DLIB_ADD_BLAS_BINDING(float, row_major_layout, trans(sm)*sm)
const CBLAS_ORDER Order = CblasRowMajor;
const CBLAS_TRANSPOSE TransA = CblasTrans;
const CBLAS_TRANSPOSE TransB = CblasNoTrans;
const int M = static_cast<int>(src.nr());
const int N = static_cast<int>(src.nc());
const int K = static_cast<int>(src.lhs.nc());
const float alpha = 1;
const float* A = &src.lhs.m(0,0);
const int lda = src.lhs.m.nc();
const float* B = &src.rhs(0,0);
const int ldb = src.rhs.nc();
const float beta = 0;
float* C = &dest(0,0);
const int ldc = src.nc();
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}};
#endif // DLIB_FOUND_BLAS
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_MATRIx_BLAS_BINDINGS_
// Copyright (C) 2008 Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_MATRIx_DEFAULT_MULTIPLY_
#define DLIB_MATRIx_DEFAULT_MULTIPLY_
#include "../geometry.h"
#include "matrix.h"
#include "matrix_utilities.h"
#include "../enable_if.h"
namespace dlib
{
// ------------------------------------------------------------------------------------
namespace ma
{
    // matrix_is_vector<EXP>::value is a compile time flag that is true when
    // EXP has exactly one row or exactly one column according to its static
    // NR / NC dimensions, and false otherwise.

    // General case: not a vector.
    template <
        typename EXP,
        typename enable = void
        >
    struct matrix_is_vector
    {
        static const bool value = false;
    };

    // Selected (through enable_if_c) only when EXP::NR or EXP::NC equals 1.
    template <
        typename EXP
        >
    struct matrix_is_vector<EXP, typename enable_if_c<EXP::NR==1 || EXP::NC==1>::type >
    {
        static const bool value = true;
    };
}
// ------------------------------------------------------------------------------------
// Forward declaration of the vector overload of default_matrix_multiply()
// together with the contract shared by every default_matrix_multiply()
// overload in this file.
//
// The enabling condition must be spelled exactly like the one on the
// definition that follows (lhs is a vector OR rhs is a vector).  It used to
// read "lhs is a vector AND rhs is a vector", which declared a separate,
// never defined function template with the same parameter list.  A
// vector*vector product then matched both that template and the "OR"
// definition, making the call ambiguous.
template <
    typename matrix_dest_type,
    typename EXP1,
    typename EXP2
    >
typename enable_if_c<ma::matrix_is_vector<EXP1>::value == true || ma::matrix_is_vector<EXP2>::value == true>::type
default_matrix_multiply (
    matrix_dest_type& dest,
    const EXP1& lhs,
    const EXP2& rhs
);
/*!
    requires
        - (lhs*rhs).destructively_aliases(dest) == false
        - dest.nr() == (lhs*rhs).nr()
        - dest.nc() == (lhs*rhs).nc()
    ensures
        - #dest == dest + lhs*rhs
          (i.e. the product is accumulated into dest, so callers that want a
          plain assignment must zero dest first)
!*/
// ------------------------------------------------------------------------------------
template <
    typename matrix_dest_type,
    typename EXP1,
    typename EXP2
    >
typename enable_if_c<ma::matrix_is_vector<EXP1>::value == true || ma::matrix_is_vector<EXP2>::value == true>::type
default_matrix_multiply (
    matrix_dest_type& dest,
    const EXP1& lhs,
    const EXP2& rhs
)
/*!
    Overload used when lhs or rhs is a row or column vector.
    Accumulates the product into dest:  #dest == dest + lhs*rhs.
    dest is never cleared here.
!*/
{
    typedef typename EXP2::type scalar_type;

    // The loop nest is ordered so that the innermost loop walks along a row
    // of both rhs and dest, which suits row major storage.
    for (long row = 0; row < lhs.nr(); ++row)
    {
        for (long k = 0; k < lhs.nc(); ++k)
        {
            // lhs(row,k) is reused for the whole inner loop, so read it once.
            const scalar_type lhs_rk = lhs(row,k);
            for (long col = 0; col < rhs.nc(); ++col)
            {
                dest(row,col) += rhs(k,col)*lhs_rk;
            }
        }
    }
}
// ------------------------------------------------------------------------------------
template <
    typename matrix_dest_type,
    typename EXP1,
    typename EXP2
    >
typename enable_if_c<ma::matrix_is_vector<EXP1>::value == false && ma::matrix_is_vector<EXP2>::value == false>::type
default_matrix_multiply (
    matrix_dest_type& dest,
    const EXP1& lhs,
    const EXP2& rhs
)
/*!
    Overload used when neither lhs nor rhs is a row or column vector.
    Accumulates the product into dest:  #dest == dest + lhs*rhs.
    dest is never cleared here.

    Small operands go through a plain triple loop.  Larger ones are
    processed in bs x bs blocks so the working set stays cache friendly.
!*/
{
    // Edge length of a block in the blocked algorithm.
    const long bs = 90;

    // if the matrices are small enough then just use the simple multiply algorithm
    if (lhs.nc() <= 2 || rhs.nc() <= 2 || lhs.nr() <= 2 || rhs.nr() <= 2 || (lhs.size() <= bs*10 && rhs.size() <= bs*10) )
    {
        // This loop is optimized assuming that the data is laid out in
        // row major order in memory.
        for (long r = 0; r < lhs.nr(); ++r)
        {
            for (long c = 0; c < lhs.nc(); ++c)
            {
                const typename EXP2::type temp = lhs(r,c);
                for (long i = 0; i < rhs.nc(); ++i)
                {
                    dest(r,i) += rhs(c,i)*temp;
                }
            }
        }
    }
    else
    {
        // The operands are big enough that a cache friendly algorithm which
        // computes the product block by block pays off.

        // Loop over all the blocks in the lhs matrix.  (block_r, block_c) is
        // the top left corner of the current lhs block and the *_end values
        // are the inclusive last row/column of that block, clipped to the
        // matrix bounds.
        for (long block_r = 0; block_r < lhs.nr(); block_r += bs)
        {
            const long r_end = std::min(block_r+bs-1, lhs.nr()-1);

            for (long block_c = 0; block_c < lhs.nc(); block_c += bs)
            {
                const long c_end = std::min(block_c+bs-1, lhs.nc()-1);

                // Now loop over all the rhs blocks that have to be multiplied
                // with the current lhs block.  Their rows are the columns of
                // the lhs block, so only the column range needs computing.
                for (long block_i = 0; block_i < rhs.nc(); block_i += bs)
                {
                    const long i_end = std::min(block_i+bs-1, rhs.nc()-1);

                    // This loop is optimized assuming that the data is laid
                    // out in row major order in memory.
                    for (long r = block_r; r <= r_end; ++r)
                    {
                        for (long c = block_c; c <= c_end; ++c)
                        {
                            const typename EXP2::type temp = lhs(r,c);
                            for (long i = block_i; i <= i_end; ++i)
                            {
                                dest(r,i) += rhs(c,i)*temp;
                            }
                        }
                    }
                }
            }
        }
    }
}
// ------------------------------------------------------------------------------------
}
#endif // DLIB_MATRIx_DEFAULT_MULTIPLY_
...@@ -285,7 +285,6 @@ namespace dlib ...@@ -285,7 +285,6 @@ namespace dlib
long nc ( long nc (
) const { return OP::nc(m); } ) const { return OP::nc(m); }
private:
const M& m; const M& m;
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment