Commit 88b37e6d authored by Davis King's avatar Davis King

Added the ability to add/subtract scalar values to/from all the elements of a matrix

using the - and + operators.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403909
parent fdb2b7d0
......@@ -14,6 +14,7 @@
#include "../is_kind.h"
#include "matrix_data_layout.h"
#include "matrix_assign_fwd.h"
#include "matrix_op.h"
#ifdef _MSC_VER
// Disable the following warnings for Visual Studio
......@@ -236,7 +237,7 @@ namespace dlib
//
// Also, the reason we want to apply this transformation in the first place is because it (1) makes
// the expressions going into matrix multiply expressions simpler and (2) it makes it a lot more
// straight forward to bind BLAS calls to matrix expressions involving scalar multiplies.
// straightforward to bind BLAS calls to matrix expressions involving scalar multiplies.
template < typename EXP1, typename EXP2 >
inline const typename disable_if_c< matrix_multiply_exp<matrix_mul_scal_exp<EXP1>, matrix_mul_scal_exp<EXP2> >::both_are_costly ,
matrix_mul_scal_exp<matrix_multiply_exp<EXP1, EXP2>,false> >::type operator* (
......@@ -764,6 +765,123 @@ namespace dlib
return matrix_mul_scal_exp<EXP>(m.m,-1*m.s);
}
// ----------------------------------------------------------------------------------------
template <typename M>
struct op_add_scalar : basic_op_m<M>
{
typedef typename M::type type;
op_add_scalar( const M& m_, const type& s_) : basic_op_m<M>(m_), s(s_){}
const type s;
const static long cost = M::cost+1;
typedef const typename M::type const_ret_type;
const_ret_type apply (long r, long c) const
{
return this->m(r,c) + s;
}
};
template <
typename EXP,
typename T
>
const typename disable_if<is_matrix<T>, matrix_op<op_add_scalar<EXP> > >::type operator+ (
const matrix_exp<EXP>& m,
const T& val
)
{
typedef typename EXP::type type;
typedef op_add_scalar<EXP> op;
return matrix_op<op>(op(m.ref(), static_cast<type>(val)));
}
template <
typename EXP,
typename T
>
const typename disable_if<is_matrix<T>, matrix_op<op_add_scalar<EXP> > >::type operator+ (
const T& val,
const matrix_exp<EXP>& m
)
{
typedef typename EXP::type type;
typedef op_add_scalar<EXP> op;
return matrix_op<op>(op(m.ref(), static_cast<type>(val)));
}
// ----------------------------------------------------------------------------------------
template <typename M>
struct op_subl_scalar : basic_op_m<M>
{
typedef typename M::type type;
op_subl_scalar( const M& m_, const type& s_) : basic_op_m<M>(m_), s(s_){}
const type s;
const static long cost = M::cost+1;
typedef const typename M::type const_ret_type;
const_ret_type apply (long r, long c) const
{
return s - this->m(r,c) ;
}
};
template <
typename EXP,
typename T
>
const typename disable_if<is_matrix<T>, matrix_op<op_subl_scalar<EXP> > >::type operator- (
const T& val,
const matrix_exp<EXP>& m
)
{
typedef typename EXP::type type;
typedef op_subl_scalar<EXP> op;
return matrix_op<op>(op(m.ref(), static_cast<type>(val)));
}
// ----------------------------------------------------------------------------------------
template <typename M>
struct op_subr_scalar : basic_op_m<M>
{
typedef typename M::type type;
op_subr_scalar( const M& m_, const type& s_) : basic_op_m<M>(m_), s(s_){}
const type s;
const static long cost = M::cost+1;
typedef const typename M::type const_ret_type;
const_ret_type apply (long r, long c) const
{
return this->m(r,c) - s;
}
};
template <
typename EXP,
typename T
>
const typename disable_if<is_matrix<T>, matrix_op<op_subr_scalar<EXP> > >::type operator- (
const matrix_exp<EXP>& m,
const T& val
)
{
typedef typename EXP::type type;
typedef op_subr_scalar<EXP> op;
return matrix_op<op>(op(m.ref(), static_cast<type>(val)));
}
// ----------------------------------------------------------------------------------------
template <
......@@ -1320,6 +1438,28 @@ namespace dlib
return *this;
}
matrix& operator += (
const T& val
)
{
const long size = nr()*nc();
for (long i = 0; i < size; ++i)
data(i) += val;
return *this;
}
matrix& operator -= (
const T& val
)
{
const long size = nr()*nc();
for (long i = 0; i < size; ++i)
data(i) -= val;
return *this;
}
matrix& operator *= (
const T& a
)
......
......@@ -108,6 +108,53 @@ namespace dlib
scalar value. The resulting matrix will have the same dimensions as m.
!*/
template <typename T>
const matrix_exp operator+ (
const matrix_exp& m,
const T& value
);
/*!
ensures
- returns the result of adding value to all the elements of matrix m.
The resulting matrix will have the same dimensions as m.
!*/
template <typename T>
const matrix_exp operator+ (
const T& value,
const matrix_exp& m
);
/*!
ensures
- returns the result of adding value to all the elements of matrix m.
The resulting matrix will have the same dimensions as m.
!*/
template <typename T>
const matrix_exp operator- (
const matrix_exp& m,
const T& value
);
/*!
ensures
- returns the result of subtracting value from all the elements of matrix m.
The resulting matrix will have the same dimensions as m.
!*/
template <typename T>
const matrix_exp operator- (
const T& value,
const matrix_exp& m
);
/*!
ensures
- Returns a matrix M such that:
- M has the same dimensions as m
- M contains the same type of element as m
- for all valid r and c:
- M(r,c) == value - m(r,c)
!*/
bool operator== (
const matrix_exp& m1,
const matrix_exp& m2
......@@ -457,6 +504,22 @@ namespace dlib
- returns *this
!*/
template <typename EXP>
matrix& operator *= (
const matrix_exp<EXP>& m
);
/*!
requires
- matrix_exp<EXP>::type == T
(i.e. m must contain the same type of element as *this)
- nc() == m.nr()
- size() > 0 && m.size() > 0
(you can't multiply any sort of empty matrices together)
ensures
- #(*this) == *this * m
- returns *this
!*/
matrix& operator *= (
const T& a
);
......@@ -475,6 +538,24 @@ namespace dlib
- returns *this
!*/
matrix& operator += (
const T& a
);
/*!
ensures
- #(*this) == *this + a
- returns *this
!*/
matrix& operator -= (
const T& a
);
/*!
ensures
- #(*this) == *this - a
- returns *this
!*/
const literal_assign_helper operator = (
const T& val
);
......
......@@ -214,6 +214,161 @@ namespace
b *= b;
DLIB_TEST(b == a*a);
}
{
matrix<double> m(2,3), m2(2,3);
m = 1,2,3,
4,5,6;
m2 = 3,4,5,
6,7,8;
DLIB_TEST(m + 2 == m2);
DLIB_TEST(2 + m == m2);
m += 2;
DLIB_TEST(m == m2);
m -= 2;
m2 = 0,1,2,
3,4,5;
DLIB_TEST(m - 1 == m2);
m -= 1;
DLIB_TEST(m == m2);
m += 1;
m2 = 5,4,3,
2,1,0;
DLIB_TEST(6 - m == m2);
}
{
matrix<float> m(2,3), m2(2,3);
m = 1,2,3,
4,5,6;
m2 = 3,4,5,
6,7,8;
DLIB_TEST(m + 2 == m2);
DLIB_TEST(2 + m == m2);
m += 2;
DLIB_TEST(m == m2);
m -= 2;
m2 = 0,1,2,
3,4,5;
DLIB_TEST(m - 1 == m2);
m -= 1;
DLIB_TEST(m == m2);
m += 1;
m2 = 5,4,3,
2,1,0;
DLIB_TEST(6 - m == m2);
}
{
matrix<int> m(2,3), m2(2,3);
m = 1,2,3,
4,5,6;
m2 = 3,4,5,
6,7,8;
DLIB_TEST(m + 2 == m2);
DLIB_TEST(2 + m == m2);
m += 2;
DLIB_TEST(m == m2);
m -= 2;
m2 = 0,1,2,
3,4,5;
DLIB_TEST(m - 1 == m2);
m -= 1;
DLIB_TEST(m == m2);
m += 1;
m2 = 5,4,3,
2,1,0;
DLIB_TEST(6 - m == m2);
}
{
matrix<int,2,3> m, m2;
m = 1,2,3,
4,5,6;
m2 = 3,4,5,
6,7,8;
DLIB_TEST(m + 2 == m2);
DLIB_TEST(2 + m == m2);
m += 2;
DLIB_TEST(m == m2);
m -= 2;
m2 = 0,1,2,
3,4,5;
DLIB_TEST(m - 1 == m2);
m -= 1;
DLIB_TEST(m == m2);
m += 1;
m2 = 5,4,3,
2,1,0;
DLIB_TEST(6 - m == m2);
}
{
matrix<double> m(2,3), m2(3,2);
m = 1,2,3,
4,5,6;
m2 = 2,5,
3,6,
4,7;
DLIB_TEST(trans(m+1) == m2);
DLIB_TEST(trans(m)+1 == m2);
DLIB_TEST(1+trans(m) == m2);
DLIB_TEST(1+m-1 == m);
m = trans(m+1);
DLIB_TEST(m == m2);
m = trans(m-1);
DLIB_TEST(trans(m+1) == m2);
m = trans(m)+1;
DLIB_TEST(m == m2);
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment