Commit cae2fb28 authored by Davis King's avatar Davis King

Added the range() function and an overload of subm() that allows you to

pick out slices of matrices like in Matlab.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402696
parent f9f34de5
......@@ -1803,6 +1803,153 @@ convergence:
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
class matrix_range_exp
{
public:
typedef long type;
typedef matrix_range_exp ref_type;
typedef memory_manager<char>::kernel_1a mem_manager_type;
const static long NR = 0;
const static long NC = 1;
matrix_range_exp (
long start_,
long end_
)
{
start = start_;
inc = 1;
nr_ = end_ - start_ + 1;
}
matrix_range_exp (
long start_,
long inc_,
long end_
)
{
start = start_;
inc = inc_;
nr_ = (end_ - start_)/inc_ + 1;
}
const long operator() (
long r,
long c
) const { return start + r*inc; }
template <typename U, long iNR, long iNC , typename MM>
bool aliases (
const matrix<U,iNR,iNC,MM>& item
) const { return false; }
template <typename U, long iNR, long iNC, typename MM >
bool destructively_aliases (
const matrix<U,iNR,iNC,MM>& item
) const { return false; }
long nr (
) const { return nr_; }
long nc (
) const { return NC; }
const ref_type& ref(
) const { return *this; }
long nr_;
long start;
long inc;
};
// ----------------------------------------------------------------------------------------
template <long start, long inc, long end>
class matrix_range_static_exp
{
public:
typedef long type;
typedef matrix_range_static_exp ref_type;
typedef memory_manager<char>::kernel_1a mem_manager_type;
const static long NR = (end - start)/inc + 1;
const static long NC = 1;
const long operator() (
long r,
long c
) const { return start + r*inc; }
template <typename U, long iNR, long iNC , typename MM>
bool aliases (
const matrix<U,iNR,iNC,MM>& item
) const { return false; }
template <typename U, long iNR, long iNC, typename MM >
bool destructively_aliases (
const matrix<U,iNR,iNC,MM>& item
) const { return false; }
long nr (
) const { return NR; }
long nc (
) const { return NC; }
const ref_type& ref(
) const { return *this; }
};
// ----------------------------------------------------------------------------------------
template <long start, long inc, long end>
const matrix_exp<matrix_range_static_exp<start,inc,end> > range (
)
{
COMPILE_TIME_ASSERT(start <= end);
return matrix_exp<matrix_range_static_exp<start,inc,end> >(matrix_range_static_exp<start,inc,end>());
}
template <long start, long end>
const matrix_exp<matrix_range_static_exp<start,1,end> > range (
)
{
COMPILE_TIME_ASSERT(start <= end);
return matrix_exp<matrix_range_static_exp<start,1,end> >(matrix_range_static_exp<start,1,end>());
}
inline const matrix_exp<matrix_range_exp> range (
long start,
long end
)
{
DLIB_ASSERT(start <= end,
"\tconst matrix_exp range(start, end)"
<< "\n\tstart can't be bigger than end"
<< "\n\tstart: " << start
<< "\n\tend: " << end
);
return matrix_exp<matrix_range_exp>(matrix_range_exp(start,end));
}
inline const matrix_exp<matrix_range_exp> range (
long start,
long inc,
long end
)
{
DLIB_ASSERT(start <= end,
"\tconst matrix_exp range(start, inc, end)"
<< "\n\tstart can't be bigger than end"
<< "\n\tstart: " << start
<< "\n\tend: " << end
);
return matrix_exp<matrix_range_exp>(matrix_range_exp(start,inc,end));
}
// ----------------------------------------------------------------------------------------
template <
......@@ -1949,6 +2096,105 @@ convergence:
return matrix_exp<exp>(exp(m,rect.top(),rect.left(),rect.height(),rect.width()));
}
// ----------------------------------------------------------------------------------------
template <
typename M,
typename EXPr,
typename EXPc
>
class matrix_sub_range_exp
{
/*!
REQUIREMENTS ON M, EXPr and EXPc
- must be a matrix_exp or matrix_ref object (or
an object with a compatible interface).
!*/
public:
typedef typename M::type type;
typedef matrix_sub_range_exp ref_type;
typedef typename M::mem_manager_type mem_manager_type;
const static long NR = EXPr::NR*EXPr::NC;
const static long NC = EXPc::NR*EXPr::NC;
matrix_sub_range_exp (
const M& m_,
const EXPr& rows_,
const EXPc& cols_
) :
m(m_),
rows(rows_),
cols(cols_)
{
}
const typename M::type operator() (
long r,
long c
) const { return m(rows(r),cols(c)); }
template <typename U, long iNR, long iNC, typename MM >
bool aliases (
const matrix<U,iNR,iNC,MM>& item
) const { return m.aliases(item) || rows.aliases(item) || cols.aliases(item); }
template <typename U, long iNR, long iNC , typename MM>
bool destructively_aliases (
const matrix<U,iNR,iNC,MM>& item
) const { return m.aliases(item) || rows.aliases(item) || cols.aliases(item); }
const ref_type& ref(
) const { return *this; }
long nr (
) const { return rows.size(); }
long nc (
) const { return cols.size(); }
private:
const M m;
EXPr rows;
EXPc cols;
};
template <
typename EXP,
typename EXPr,
typename EXPc
>
const matrix_exp<matrix_sub_range_exp<matrix_exp<EXP>,matrix_exp<EXPr>,matrix_exp<EXPc> > > subm (
const matrix_exp<EXP>& m,
const matrix_exp<EXPr>& rows,
const matrix_exp<EXPc>& cols
)
{
// the rows and cols matrices must contain elements of type long
COMPILE_TIME_ASSERT((is_same_type<typename EXPr::type,long>::value == true));
COMPILE_TIME_ASSERT((is_same_type<typename EXPc::type,long>::value == true));
DLIB_ASSERT(0 <= min(rows) && max(rows) < m.nr() && 0 <= min(cols) && max(cols) < m.nc() &&
(rows.nr() == 1 || rows.nc() == 1) && (cols.nr() == 1 || cols.nc() == 1),
"\tconst matrix_exp subm(const matrix_exp& m, const matrix_exp& rows, const matrix_exp& cols)"
<< "\n\tYou have given invalid arguments to this function"
<< "\n\tm.nr(): " << m.nr()
<< "\n\tm.nc(): " << m.nc()
<< "\n\tmin(rows): " << min(rows)
<< "\n\tmax(rows): " << max(rows)
<< "\n\tmin(cols): " << min(cols)
<< "\n\tmax(cols): " << max(cols)
<< "\n\trows.nr(): " << rows.nr()
<< "\n\trows.nc(): " << rows.nc()
<< "\n\tcols.nr(): " << cols.nr()
<< "\n\tcols.nc(): " << cols.nc()
);
typedef matrix_sub_range_exp<matrix_exp<EXP>,matrix_exp<EXPr>,matrix_exp<EXPc> > exp;
return matrix_exp<exp>(exp(m,rows,cols));
}
// ----------------------------------------------------------------------------------------
struct op_rowm
......
......@@ -213,6 +213,71 @@ namespace dlib
the matrix m)
!*/
// ----------------------------------------------------------------------------------------
template <long start, long inc, long end>
const matrix_exp range (
);
/*!
requires
- start <= end
ensures
- returns a matrix R such that:
- R::type == long
- R.nr() == (end - start)/inc + 1
- R.nc() == 1
- R(i) == start + i*inc
!*/
template <long start, long end>
const matrix_exp range (
) { return range<start,1,end>(); }
const matrix_exp range (
long start,
long inc,
long end
);
/*!
requires
- start <= end
ensures
- returns a matrix R such that:
- R::type == long
- R.nr() == (end - start)/inc + 1
- R.nc() == 1
- R(i) == start + i*inc
!*/
const matrix_exp range (
long start,
long end
) { return range(start,1,end); }
// ----------------------------------------------------------------------------------------
const matrix_exp subm (
const matrix_exp& m,
const matrix_exp& rows,
const matrix_exp& cols,
);
/*!
requires
- rows and cols contain elements of type long
- 0 <= min(rows) && max(rows) < m.nr()
- 0 <= min(cols) && max(cols) < m.nc()
- rows.nr() == 1 || rows.nc() == 1
- cols.nr() == 1 || cols.nc() == 1
(i.e. rows and cols must be vectors)
ensures
- returns a matrix R such that:
- R::type == the same type that was in m
- R.nr() == rows.size()
- R.nc() == cols.size()
- for all valid r and c:
R(r,c) == m(rows(r),cols(c))
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp subm (
......
......@@ -1615,6 +1615,139 @@ namespace
DLIB_CASSERT(diagm(i) == m,"");
}
{
DLIB_CASSERT(range(0,5).nr() == 6,"");
DLIB_CASSERT(range(1,5).nr() == 5,"");
DLIB_CASSERT(range(0,5).nc() == 1,"");
DLIB_CASSERT(range(1,5).nc() == 1,"");
DLIB_CASSERT(trans(range(0,5)).nc() == 6,"");
DLIB_CASSERT(trans(range(1,5)).nc() == 5,"");
DLIB_CASSERT(trans(range(0,5)).nr() == 1,"");
DLIB_CASSERT(trans(range(1,5)).nr() == 1,"");
DLIB_CASSERT(range(0,2,5).nr() == 3,"");
DLIB_CASSERT(range(1,2,5).nr() == 3,"");
DLIB_CASSERT(range(0,2,5).nc() == 1,"");
DLIB_CASSERT(range(1,2,5).nc() == 1,"");
DLIB_CASSERT(trans(range(0,2,5)).nc() == 3,"");
DLIB_CASSERT(trans(range(1,2,5)).nc() == 3,"");
DLIB_CASSERT(trans(range(0,2,5)).nr() == 1,"");
DLIB_CASSERT(trans(range(1,2,5)).nr() == 1,"");
DLIB_CASSERT(range(0,3,6).nr() == 3,"");
DLIB_CASSERT(range(1,3,5).nr() == 2,"");
DLIB_CASSERT(range(0,3,5).nc() == 1,"");
DLIB_CASSERT(range(1,3,5).nc() == 1,"");
DLIB_CASSERT(trans(range(0,3,6)).nc() == 3,"");
DLIB_CASSERT(trans(range(1,3,5)).nc() == 2,"");
DLIB_CASSERT(trans(range(0,3,5)).nr() == 1,"");
DLIB_CASSERT(trans(range(1,3,5)).nr() == 1,"");
DLIB_CASSERT(range(1,9,5).nr() == 1,"");
DLIB_CASSERT(range(1,9,5).nc() == 1,"");
DLIB_CASSERT(range(0,0).nr() == 1,"");
DLIB_CASSERT(range(0,0).nc() == 1,"");
DLIB_CASSERT(range(1,1)(0) == 1,"");
DLIB_CASSERT(range(0,5)(0) == 0 && range(0,5)(1) == 1 && range(0,5)(5) == 5,"");
DLIB_CASSERT(range(1,2,5)(0) == 1 && range(1,2,5)(1) == 3 && range(1,2,5)(2) == 5,"");
DLIB_CASSERT((range<0,5>()(0) == 0 && range<0,5>()(1) == 1 && range<0,5>()(5) == 5),"");
DLIB_CASSERT((range<1,2,5>()(0) == 1 && range<1,2,5>()(1) == 3 && range<1,2,5>()(2) == 5),"");
DLIB_CASSERT((range<0,5>().nr() == 6),"");
DLIB_CASSERT((range<1,5>().nr() == 5),"");
DLIB_CASSERT((range<0,5>().nc() == 1),"");
DLIB_CASSERT((range<1,5>().nc() == 1),"");
DLIB_CASSERT((trans(range<0,5>()).nc() == 6),"");
DLIB_CASSERT((trans(range<1,5>()).nc() == 5),"");
DLIB_CASSERT((trans(range<0,5>()).nr() == 1),"");
DLIB_CASSERT((trans(range<1,5>()).nr() == 1),"");
DLIB_CASSERT((range<0,2,5>().nr() == 3),"");
DLIB_CASSERT((range<1,2,5>().nr() == 3),"");
DLIB_CASSERT((range<0,2,5>().nc() == 1),"");
DLIB_CASSERT((range<1,2,5>().nc() == 1),"");
DLIB_CASSERT((trans(range<0,2,5>()).nc() == 3),"");
DLIB_CASSERT((trans(range<1,2,5>()).nc() == 3),"");
DLIB_CASSERT((trans(range<0,2,5>()).nr() == 1),"");
DLIB_CASSERT((trans(range<1,2,5>()).nr() == 1),"");
DLIB_CASSERT((range<0,3,6>().nr() == 3),"");
DLIB_CASSERT((range<1,3,5>().nr() == 2),"");
DLIB_CASSERT((range<0,3,5>().nc() == 1),"");
DLIB_CASSERT((range<1,3,5>().nc() == 1),"");
DLIB_CASSERT((trans(range<0,3,6>()).nc() == 3),"");
DLIB_CASSERT((trans(range<1,3,5>()).nc() == 2),"");
DLIB_CASSERT((trans(range<0,3,5>()).nr() == 1),"");
DLIB_CASSERT((trans(range<1,3,5>()).nr() == 1),"");
}
{
matrix<double> m(4,3);
for (long r = 0; r < m.nr(); ++r)
{
for (long c = 0; c < m.nc(); ++c)
{
m(r,c) = r*c;
}
}
DLIB_CASSERT(subm(m,range(0,3),range(0,0)) == colm(m,0),"");
DLIB_CASSERT(subm(m,range(0,3),range(1,1)) == colm(m,1),"");
DLIB_CASSERT(subm(m,range(0,3),range(2,2)) == colm(m,2),"");
DLIB_CASSERT(subm(m,range(0,0),range(0,2)) == rowm(m,0),"");
DLIB_CASSERT(subm(m,range(1,1),range(0,2)) == rowm(m,1),"");
DLIB_CASSERT(subm(m,range(2,2),range(0,2)) == rowm(m,2),"");
DLIB_CASSERT(subm(m,range(3,3),range(0,2)) == rowm(m,3),"");
DLIB_CASSERT(subm(m,0,0,2,2) == subm(m,range(0,1),range(0,1)),"");
DLIB_CASSERT(subm(m,1,1,2,2) == subm(m,range(1,2),range(1,2)),"");
matrix<double,2,2> m2 = subm(m,range(0,2,2),range(0,2,2));
DLIB_CASSERT(m2(0,0) == m(0,0),"");
DLIB_CASSERT(m2(0,1) == m(0,2),"");
DLIB_CASSERT(m2(1,0) == m(2,0),"");
DLIB_CASSERT(m2(1,1) == m(2,2),"");
}
{
matrix<double,4,3> m(4,3);
for (long r = 0; r < m.nr(); ++r)
{
for (long c = 0; c < m.nc(); ++c)
{
m(r,c) = r*c;
}
}
DLIB_CASSERT(subm(m,range<0,3>(),range<0,0>()) == colm(m,0),"");
DLIB_CASSERT(subm(m,range<0,3>(),range<1,1>()) == colm(m,1),"");
DLIB_CASSERT(subm(m,range<0,3>(),range<2,2>()) == colm(m,2),"");
DLIB_CASSERT(subm(m,range<0,0>(),range<0,2>()) == rowm(m,0),"");
DLIB_CASSERT(subm(m,range<1,1>(),range<0,2>()) == rowm(m,1),"");
DLIB_CASSERT(subm(m,range<2,2>(),range<0,2>()) == rowm(m,2),"");
DLIB_CASSERT(subm(m,range<3,3>(),range<0,2>()) == rowm(m,3),"");
DLIB_CASSERT(subm(m,0,0,2,2) == subm(m,range<0,1>(),range<0,1>()),"");
DLIB_CASSERT(subm(m,1,1,2,2) == subm(m,range<1,2>(),range<1,2>()),"");
matrix<double,2,2> m2 = subm(m,range<0,2,2>(),range<0,2,2>());
DLIB_CASSERT(m2(0,0) == m(0,0),"");
DLIB_CASSERT(m2(0,1) == m(0,2),"");
DLIB_CASSERT(m2(1,0) == m(2,0),"");
DLIB_CASSERT(m2(1,1) == m(2,2),"");
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment