Commit 4ec03e7b authored by Davis King's avatar Davis King

Added join_rows() and join_cols()

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403580
parent 003bbd70
......@@ -2737,6 +2737,116 @@ namespace dlib
return exp(m.ref());
}
// ----------------------------------------------------------------------------------------
struct op_join_rows
{
template <typename EXP1, typename EXP2>
struct op : public has_destructive_aliasing
{
const static long cost = EXP1::cost + EXP2::cost + 1;
const static long NR = tmax<EXP1::NR, EXP2::NR>::value;
const static long NC = (EXP1::NC*EXP2::NC != 0)? (EXP1::NC+EXP2::NC) : (0);
typedef typename EXP1::type type;
typedef typename EXP1::const_ret_type const_ret_type;
typedef typename EXP1::mem_manager_type mem_manager_type;
template <typename M1, typename M2>
static const_ret_type apply ( const M1& m1, const M2& m2 , long r, long c)
{
if (c < m1.nc())
return m1(r,c);
else
return m2(r,c-m1.nc());
}
template <typename M1, typename M2>
static long nr (const M1& m1, const M2& ) { return m1.nr(); }
template <typename M1, typename M2>
static long nc (const M1& m1, const M2& m2 ) { return m1.nc()+m2.nc(); }
};
};
template <
typename EXP1,
typename EXP2
>
inline const matrix_binary_exp<EXP1,EXP2,op_join_rows> join_rows (
const matrix_exp<EXP1>& a,
const matrix_exp<EXP2>& b
)
{
COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
// You are getting an error on this line because you are trying to join two matrices that
// don't have the same number of rows
COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || (EXP1::NR*EXP2::NR == 0));
DLIB_ASSERT(a.nr() == b.nr(),
"\tconst matrix_exp join_rows(const matrix_exp& a, const matrix_exp& b)"
<< "\n\tYou can only use join_rows() if both matrices have the same number of rows"
<< "\n\ta.nr(): " << a.nr()
<< "\n\tb.nr(): " << b.nr()
);
typedef matrix_binary_exp<EXP1,EXP2,op_join_rows> exp;
return exp(a.ref(),b.ref());
}
// ----------------------------------------------------------------------------------------
struct op_join_cols
{
template <typename EXP1, typename EXP2>
struct op : public has_destructive_aliasing
{
const static long cost = EXP1::cost + EXP2::cost + 1;
const static long NC = tmax<EXP1::NC, EXP2::NC>::value;
const static long NR = (EXP1::NR*EXP2::NR != 0)? (EXP1::NR+EXP2::NR) : (0);
typedef typename EXP1::type type;
typedef typename EXP1::const_ret_type const_ret_type;
typedef typename EXP1::mem_manager_type mem_manager_type;
template <typename M1, typename M2>
static const_ret_type apply ( const M1& m1, const M2& m2 , long r, long c)
{
if (r < m1.nr())
return m1(r,c);
else
return m2(r-m1.nr(),c);
}
template <typename M1, typename M2>
static long nr (const M1& m1, const M2& m2 ) { return m1.nr()+m2.nr(); }
template <typename M1, typename M2>
static long nc (const M1& m1, const M2& ) { return m1.nc(); }
};
};
template <
typename EXP1,
typename EXP2
>
inline const matrix_binary_exp<EXP1,EXP2,op_join_cols> join_cols (
const matrix_exp<EXP1>& a,
const matrix_exp<EXP2>& b
)
{
COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
// You are getting an error on this line because you are trying to join two matrices that
// don't have the same number of columns
COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || (EXP1::NC*EXP2::NC == 0));
DLIB_ASSERT(a.nc() == b.nc(),
"\tconst matrix_exp join_cols(const matrix_exp& a, const matrix_exp& b)"
<< "\n\tYou can only use join_cols() if both matrices have the same number of columns"
<< "\n\ta.nc(): " << a.nc()
<< "\n\tb.nc(): " << b.nc()
);
typedef matrix_binary_exp<EXP1,EXP2,op_join_cols> exp;
return exp(a.ref(),b.ref());
}
// ----------------------------------------------------------------------------------------
}
......
......@@ -632,6 +632,52 @@ namespace dlib
performs pointwise_multiply(pointwise_multiply(a,b),pointwise_multiply(c,d));
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp join_rows (
const matrix_exp& a,
const matrix_exp& b
);
/*!
requires
- a.nr() == b.nr()
- a and b both contain the same type of element
ensures
- This function joins two matrices together by concatenating their rows.
- returns a matrix R such that:
- R::type == the same type that was in a and b.
- R.nr() == a.nr() == b.nr()
- R.nc() == a.nr() + b.nc()
- for all valid r and c:
- if (c < a.nc()) then
- R(r,c) == a(r,c)
- else
- R(r,c) == b(r, c-a.nc())
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp join_cols (
const matrix_exp& a,
const matrix_exp& b
);
/*!
requires
- a.nr() == b.nr()
- a and b both contain the same type of element
ensures
- This function joins two matrices together by concatenating their columns.
- returns a matrix R such that:
- R::type == the same type that was in a and b.
- R.nr() == a.nr() + b.nr()
- R.nc() == a.nr() == b.nc()
- for all valid r and c:
- if (r < a.nr()) then
- R(r,c) == a(r,c)
- else
- R(r,c) == b(r-a.nr(), c)
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp tensor_product (
......
......@@ -809,6 +809,79 @@ namespace
DLIB_TEST(equal(-2*a, c));
}
{
matrix<int> a, b, c;
a.set_size(2, 3);
b.set_size(2, 6);
c.set_size(4, 3);
a = 1, 2, 3,
4, 5, 6;
b = 1, 2, 3, 1, 2, 3,
4, 5, 6, 4, 5, 6;
c = 1, 2, 3,
4, 5, 6,
1, 2, 3,
4, 5, 6;
DLIB_TEST(join_rows(a,a) == b);
DLIB_TEST(join_cols(trans(a), trans(a)) == trans(b));
DLIB_TEST(join_cols(a,a) == c)
DLIB_TEST(join_rows(trans(a),trans(a)) == trans(c))
}
{
matrix<int, 2, 3> a;
matrix<int, 2, 6> b;
matrix<int, 4, 3> c;
a = 1, 2, 3,
4, 5, 6;
b = 1, 2, 3, 1, 2, 3,
4, 5, 6, 4, 5, 6;
c = 1, 2, 3,
4, 5, 6,
1, 2, 3,
4, 5, 6;
DLIB_TEST(join_rows(a,a) == b);
DLIB_TEST(join_cols(trans(a), trans(a)) == trans(b));
DLIB_TEST(join_cols(a,a) == c)
DLIB_TEST(join_rows(trans(a),trans(a)) == trans(c))
}
{
matrix<int, 2, 3> a;
matrix<int> a2;
matrix<int, 2, 6> b;
matrix<int, 4, 3> c;
a = 1, 2, 3,
4, 5, 6;
a2 = a;
b = 1, 2, 3, 1, 2, 3,
4, 5, 6, 4, 5, 6;
c = 1, 2, 3,
4, 5, 6,
1, 2, 3,
4, 5, 6;
DLIB_TEST(join_rows(a,a2) == b);
DLIB_TEST(join_rows(a2,a) == b);
DLIB_TEST(join_cols(trans(a2), trans(a)) == trans(b));
DLIB_TEST(join_cols(a2,a) == c)
DLIB_TEST(join_cols(a,a2) == c)
DLIB_TEST(join_rows(trans(a2),trans(a)) == trans(c))
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment