Commit e54bfbb9 authored by Davis King's avatar Davis King

Optimized matrix multiplication a little

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402474
parent cbafa9c5
......@@ -6,6 +6,7 @@
#include "matrix/matrix.h"
#include "matrix/matrix_utilities.h"
#include "matrix/matrix_math_functions.h"
#include "matrix/matrix_assign.h"
#endif // DLIB_MATRIx_HEADER
......
......@@ -202,25 +202,6 @@ namespace dlib
{
}
void consume(
matrix_data& item
)
/*!
ensures
- #*this == item
- #item is in an untouchable state. no one should do anything
to it other than let it destruct.
!*/
{
for (long r = 0; r < num_rows; ++r)
{
for (long c = 0; c < num_cols; ++c)
{
(*this)(r,c) = item(r,c);
}
}
}
private:
T data[num_rows][num_cols];
};
......@@ -284,22 +265,6 @@ namespace dlib
{
}
void consume(
matrix_data& item
)
/*!
ensures
- #*this == item
- #item is in an untouchable state. no one should do anything
to it other than let it destruct.
!*/
{
pool.deallocate_array(data);
data = item.data;
item.data = 0;
pool.swap(item.pool);
}
private:
T* data;
......@@ -375,23 +340,6 @@ namespace dlib
nr_ = nr;
}
void consume(
matrix_data& item
)
/*!
ensures
- #*this == item
- #item is in an untouchable state. no one should do anything
to it other than let it destruct.
!*/
{
pool.deallocate_array(data);
data = item.data;
nr_ = item.nr_;
item.data = 0;
pool.swap(item.pool);
}
private:
T* data;
......@@ -470,23 +418,6 @@ namespace dlib
nc_ = nc;
}
void consume(
matrix_data& item
)
/*!
ensures
- #*this == item
- #item is in an untouchable state. no one should do anything
to it other than let it destruct.
!*/
{
pool.deallocate_array(data);
data = item.data;
nc_ = item.nc_;
item.data = 0;
pool.swap(item.pool);
}
private:
T* data;
......@@ -567,24 +498,6 @@ namespace dlib
nc_ = nc;
}
void consume(
matrix_data& item
)
/*!
ensures
- #*this == item
- #item is in an untouchable state. no one should do anything
to it other than let it destruct.
!*/
{
pool.deallocate_array(data);
data = item.data;
nc_ = item.nc_;
nr_ = item.nr_;
item.data = 0;
pool.swap(item.pool);
}
private:
T* data;
long nr_;
......@@ -1382,34 +1295,34 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename matrix_dest_type,
typename src_exp
>
void matrix_assign (
matrix_dest_type& dest,
const matrix_exp<src_exp>& src,
const long row_offset = 0,
const long col_offset = 0
)
/*!
requires
- src.destructively_aliases(dest) == false
- dest.nr() == src.nr()-row_offset
- dest.nc() == src.nc()-col_offset
ensures
- #subm(dest, row_offset, col_offset, src.nr(), src.nc()) == src
- the part of dest outside the above sub matrix remains unchanged
!*/
{
for (long r = 0; r < src.nr(); ++r)
template <
typename matrix_dest_type,
typename src_exp
>
void matrix_assign (
matrix_dest_type& dest,
const matrix_exp<src_exp>& src,
const long row_offset = 0,
const long col_offset = 0
)
/*!
requires
- src.destructively_aliases(dest) == false
- dest.nr() == src.nr()-row_offset
- dest.nc() == src.nc()-col_offset
ensures
- #subm(dest, row_offset, col_offset, src.nr(), src.nc()) == src
- the part of dest outside the above sub matrix remains unchanged
!*/
{
for (long c = 0; c < src.nc(); ++c)
for (long r = 0; r < src.nr(); ++r)
{
dest(r+row_offset,c+col_offset) = src(r,c);
for (long c = 0; c < src.nc(); ++c)
{
dest(r+row_offset,c+col_offset) = src(r,c);
}
}
}
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
......@@ -1524,7 +1437,7 @@ namespace dlib
data.set_size(m.nr(),m.nc());
matrix_assign(data, m);
matrix_assign(*this, m);
}
matrix (
......@@ -1532,7 +1445,7 @@ namespace dlib
): matrix_exp<matrix_ref<T,num_rows,num_cols, mem_manager> >(ref_type(*this))
{
data.set_size(m.nr(),m.nc());
matrix_assign(data, m);
matrix_assign(*this, m);
}
template <typename U, size_t len>
......@@ -1774,16 +1687,16 @@ namespace dlib
if (m.destructively_aliases(*this) == false)
{
set_size(m.nr(),m.nc());
matrix_assign(data, m);
matrix_assign(*this, m);
}
else
{
// we have to use a temporary matrix_data object here because
// this->data is aliased inside the matrix_exp m somewhere.
matrix_data<T,NR,NC, mem_manager> temp;
// we have to use a temporary matrix object here because
// *this is aliased inside the matrix_exp m somewhere.
matrix temp;
temp.set_size(m.nr(),m.nc());
matrix_assign(temp, m);
data.consume(temp);
temp.swap(*this);
}
return *this;
}
......@@ -1809,16 +1722,16 @@ namespace dlib
COMPILE_TIME_ASSERT((is_same_type<typename EXP::type,type>::value == true));
if (m.destructively_aliases(*this) == false)
{
matrix_assign(data, m + *this);
matrix_assign(*this, m + *this);
}
else
{
// we have to use a temporary matrix_data object here because
// this->data is aliased inside the matrix_exp m somewhere.
matrix_data<T,NR,NC, mem_manager> temp;
matrix temp;
temp.set_size(m.nr(),m.nc());
matrix_assign(temp, m + *this);
data.consume(temp);
temp.swap(*this);
}
return *this;
}
......@@ -1845,16 +1758,16 @@ namespace dlib
COMPILE_TIME_ASSERT((is_same_type<typename EXP::type,type>::value == true));
if (m.destructively_aliases(*this) == false)
{
matrix_assign(data, *this - m);
matrix_assign(*this, *this - m);
}
else
{
// we have to use a temporary matrix_data object here because
// this->data is aliased inside the matrix_exp m somewhere.
matrix_data<T,NR,NC, mem_manager> temp;
matrix temp;
temp.set_size(m.nr(),m.nc());
matrix_assign(temp, *this - m);
data.consume(temp);
temp.swap(*this);
}
return *this;
}
......
// Copyright (C) 2008 Davis E. King (davisking@users.sourceforge.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_MATRIx_ASSIGn_
#define DLIB_MATRIx_ASSIGn_
#include "../geometry.h"
#include "matrix.h"
#include "matrix_utilities.h"
#include "../enable_if.h"
namespace dlib
{
namespace ma
{
// ------------------------------------------------------------------------------------
template <
typename EXP
>
const matrix_exp<EXP> make_exp (
const EXP& exp
)
/*!
The only point of this function is to make it easy to cause the overloads
of matrix_assign to not trigger for a matrix expression.
!*/
{
return matrix_exp<EXP>(exp);
}
// ------------------------------------------------------------------------------------
template < typename EXP, typename enable = void >
struct matrix_is_vector { static const bool value = false; };
template < typename EXP >
struct matrix_is_vector<EXP, typename enable_if_c<EXP::NR==1 || EXP::NC==1>::type > { static const bool value = true; };
template < typename EXP, typename enable = void >
struct is_small_matrix { static const bool value = false; };
template < typename EXP >
struct is_small_matrix<EXP, typename enable_if_c<EXP::NR>=1 && EXP::NC>=1 &&
EXP::NR<=100 && EXP::NC<=100>::type > { static const bool value = true; };
}
// ----------------------------------------------------------------------------------------
template <
typename matrix_dest_type,
typename EXP1,
typename EXP2,
unsigned long count
>
inline typename disable_if_c<ma::matrix_is_vector<EXP1>::value || ma::matrix_is_vector<EXP2>::value ||
ma::is_small_matrix<EXP1>::value || ma::is_small_matrix<EXP2>::value >::type matrix_assign (
matrix_dest_type& dest,
const matrix_exp<matrix_multiply_exp<EXP1,EXP2,count> >& src,
const long row_offset = 0,
const long col_offset = 0
)
/*!
This overload catches assignments like:
dest = lhs*rhs
where lhs and rhs are both not vectors
!*/
{
using namespace ma;
const matrix_exp<EXP1> lhs(src.ref().lhs);
const matrix_exp<EXP2> rhs(src.ref().rhs);
const long bs = 100;
// if the matrices are small enough then just use the simple multiply algorithm
if (lhs.nc() <= 2 || rhs.nc() <= 2 || lhs.nr() <= 2 || rhs.nr() <= 2 || (lhs.size() <= bs*10 && rhs.size() <= bs*10) )
{
for (long r = 0; r < src.nr(); ++r)
{
for (long c = 0; c < src.nc(); ++c)
{
dest(r+row_offset,c+col_offset) = make_exp(lhs*rhs)(r,c);
}
}
}
else
{
// if the lhs and rhs matrices are big enough we should use a cache friendly
// algorithm that computes the matrix multiply in blocks.
// Loop over all the blocks in the lhs matrix
for (long r = 0; r < lhs.nr(); r+=bs)
{
for (long c = 0; c < lhs.nc(); c+=bs)
{
// make a rect for the block from lhs
rectangle lhs_block(c, r, c+bs-1, r+bs-1);
lhs_block = lhs_block.intersect(get_rect(lhs));
// now loop over all the rhs blocks we have to multiply with the current lhs block
for (long i = 0; i < rhs.nc(); i += bs)
{
// make a rect for the block from rhs
rectangle rhs_block(i, c, i+bs-1, c+bs-1);
rhs_block = rhs_block.intersect(get_rect(rhs));
// make a target rect in res
rectangle res_block(rhs_block.left(),lhs_block.top(), rhs_block.right(), lhs_block.bottom());
res_block = translate_rect(res_block,col_offset, row_offset);
if (c != 0)
set_subm(dest, res_block) = subm(dest,res_block) + subm(lhs,lhs_block)*subm(rhs, rhs_block);
else
set_subm(dest, res_block) = make_exp(subm(lhs,lhs_block)*subm(rhs, rhs_block));
}
}
}
}
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_MATRIx_ASSIGn_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment