Commit 0f89c0f6 authored by Davis King's avatar Davis King

- Changed the range() stuff so that it works with ranges of the form

    range(1,5) as well as range(5,1)
  - Added overloads of colm(), rowm(), set_colm(), and set_rowm() that
    can work with ranges.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402722
parent 53ea098d
......@@ -1833,8 +1833,11 @@ convergence:
)
{
start = start_;
if (start_ <= end_)
inc = 1;
nr_ = end_ - start_ + 1;
else
inc = -1;
nr_ = std::abs(end_ - start_) + 1;
}
matrix_range_exp (
long start_,
......@@ -1843,8 +1846,11 @@ convergence:
)
{
start = start_;
nr_ = std::abs(end_ - start_)/inc_ + 1;
if (start_ <= end_)
inc = inc_;
nr_ = (end_ - start_)/inc_ + 1;
else
inc = -inc_;
}
long operator() (
......@@ -1879,14 +1885,15 @@ convergence:
// ----------------------------------------------------------------------------------------
template <long start, long inc, long end>
template <long start, long inc_, long end>
class matrix_range_static_exp
{
public:
typedef long type;
typedef matrix_range_static_exp ref_type;
typedef memory_manager<char>::kernel_1a mem_manager_type;
const static long NR = (end - start)/inc + 1;
const static long inc = (start <= end)?inc_:-inc_;
const static long NR = tabs<(end - start)>::value/inc_ + 1;
const static long NC = 1;
long operator() (
......@@ -1920,7 +1927,7 @@ convergence:
const matrix_exp<matrix_range_static_exp<start,inc,end> > range (
)
{
COMPILE_TIME_ASSERT(start <= end);
COMPILE_TIME_ASSERT(inc > 0);
return matrix_exp<matrix_range_static_exp<start,inc,end> >(matrix_range_static_exp<start,inc,end>());
}
......@@ -1928,7 +1935,6 @@ convergence:
const matrix_exp<matrix_range_static_exp<start,1,end> > range (
)
{
COMPILE_TIME_ASSERT(start <= end);
return matrix_exp<matrix_range_static_exp<start,1,end> >(matrix_range_static_exp<start,1,end>());
}
......@@ -1937,13 +1943,6 @@ convergence:
long end
)
{
DLIB_ASSERT(start <= end,
"\tconst matrix_exp range(start, end)"
<< "\n\tstart can't be bigger than end"
<< "\n\tstart: " << start
<< "\n\tend: " << end
);
return matrix_exp<matrix_range_exp>(matrix_range_exp(start,end));
}
......@@ -1953,10 +1952,11 @@ convergence:
long end
)
{
DLIB_ASSERT(start <= end,
DLIB_ASSERT(inc > 0,
"\tconst matrix_exp range(start, inc, end)"
<< "\n\tstart can't be bigger than end"
<< "\n\tstart: " << start
<< "\n\tinc: " << inc
<< "\n\tend: " << end
);
......@@ -2250,6 +2250,56 @@ convergence:
return matrix_exp<exp>(exp(m,row));
}
// ----------------------------------------------------------------------------------------
struct op_rowm_range
{
template <typename EXP1, typename EXP2>
struct op : has_destructive_aliasing
{
typedef typename EXP1::type type;
typedef typename EXP1::mem_manager_type mem_manager_type;
const static long NR = EXP2::NC*EXP2::NR;
const static long NC = EXP1::NC;
template <typename M1, typename M2>
static type apply ( const M1& m1, const M2& rows , long r, long c)
{ return m1(rows(r),c); }
template <typename M1, typename M2>
static long nr (const M1& m1, const M2& rows ) { return rows.size(); }
template <typename M1, typename M2>
static long nc (const M1& m1, const M2& ) { return m1.nc(); }
};
};
template <
typename EXP1,
typename EXP2
>
const matrix_exp<matrix_binary_exp<matrix_exp<EXP1>,matrix_exp<EXP2>,op_rowm_range> > rowm (
const matrix_exp<EXP1>& m,
const matrix_exp<EXP2>& rows
)
{
// the rows matrix must contain elements of type long
COMPILE_TIME_ASSERT((is_same_type<typename EXP2::type,long>::value == true));
DLIB_ASSERT(0 <= min(rows) && max(rows) < m.nr() && (rows.nr() == 1 || rows.nc() == 1),
"\tconst matrix_exp rowm(const matrix_exp& m, const matrix_exp& rows)"
<< "\n\tYou have given invalid arguments to this function"
<< "\n\tm.nr(): " << m.nr()
<< "\n\tm.nc(): " << m.nc()
<< "\n\tmin(rows): " << min(rows)
<< "\n\tmax(rows): " << max(rows)
<< "\n\trows.nr(): " << rows.nr()
<< "\n\trows.nc(): " << rows.nc()
);
typedef matrix_binary_exp<matrix_exp<EXP1>,matrix_exp<EXP2>,op_rowm_range> exp;
return matrix_exp<exp>(exp(m,rows));
}
// ----------------------------------------------------------------------------------------
struct op_colm
......@@ -2292,6 +2342,56 @@ convergence:
return matrix_exp<exp>(exp(m,col));
}
// ----------------------------------------------------------------------------------------
struct op_colm_range
{
template <typename EXP1, typename EXP2>
struct op : has_destructive_aliasing
{
typedef typename EXP1::type type;
typedef typename EXP1::mem_manager_type mem_manager_type;
const static long NR = EXP1::NR;
const static long NC = EXP2::NC*EXP2::NR;
template <typename M1, typename M2>
static type apply ( const M1& m1, const M2& cols , long r, long c)
{ return m1(r,cols(c)); }
template <typename M1, typename M2>
static long nr (const M1& m1, const M2& cols ) { return m1.nr(); }
template <typename M1, typename M2>
static long nc (const M1& m1, const M2& cols ) { return cols.size(); }
};
};
template <
typename EXP1,
typename EXP2
>
const matrix_exp<matrix_binary_exp<matrix_exp<EXP1>,matrix_exp<EXP2>,op_colm_range> > colm (
const matrix_exp<EXP1>& m,
const matrix_exp<EXP2>& cols
)
{
// the cols matrix must contain elements of type long
COMPILE_TIME_ASSERT((is_same_type<typename EXP2::type,long>::value == true));
DLIB_ASSERT(0 <= min(cols) && max(cols) < m.nc() && (cols.nr() == 1 || cols.nc() == 1),
"\tconst matrix_exp colm(const matrix_exp& m, const matrix_exp& cols)"
<< "\n\tYou have given invalid arguments to this function"
<< "\n\tm.nr(): " << m.nr()
<< "\n\tm.nc(): " << m.nc()
<< "\n\tmin(cols): " << min(cols)
<< "\n\tmax(cols): " << max(cols)
<< "\n\tcols.nr(): " << cols.nr()
<< "\n\tcols.nc(): " << cols.nc()
);
typedef matrix_binary_exp<matrix_exp<EXP1>,matrix_exp<EXP2>,op_colm_range> exp;
return matrix_exp<exp>(exp(m,cols));
}
// ----------------------------------------------------------------------------------------
......@@ -2510,6 +2610,50 @@ convergence:
return assignable_sub_range_matrix<T,NR,NC,mm,l,matrix_exp<EXPr>,matrix_exp<EXPc> >(m,rows,cols);
}
// ----------------------------------------------------------------------------------------
template <typename T, long NR, long NC, typename mm, typename l, typename EXPr>
assignable_sub_range_matrix<T,NR,NC,mm,l,matrix_exp<EXPr>,matrix_exp<matrix_range_exp> > set_rowm (
matrix<T,NR,NC,mm,l>& m,
const matrix_exp<EXPr>& rows
)
{
DLIB_ASSERT(0 <= min(rows) && max(rows) < m.nr() && (rows.nr() == 1 || rows.nc() == 1),
"\tassignable_matrix_expression set_rowm(matrix& m, const matrix_exp& rows)"
<< "\n\tYou have specified invalid sub matrix dimensions"
<< "\n\tm.nr(): " << m.nr()
<< "\n\tm.nc(): " << m.nc()
<< "\n\tmin(rows): " << min(rows)
<< "\n\tmax(rows): " << max(rows)
<< "\n\trows.nr(): " << rows.nr()
<< "\n\trows.nc(): " << rows.nc()
);
return assignable_sub_range_matrix<T,NR,NC,mm,l,matrix_exp<EXPr>,matrix_exp<matrix_range_exp> >(m,rows,range(0,m.nc()-1));
}
// ----------------------------------------------------------------------------------------
template <typename T, long NR, long NC, typename mm, typename l, typename EXPc>
assignable_sub_range_matrix<T,NR,NC,mm,l,matrix_exp<matrix_range_exp>,matrix_exp<EXPc> > set_colm (
matrix<T,NR,NC,mm,l>& m,
const matrix_exp<EXPc>& cols
)
{
DLIB_ASSERT(0 <= min(cols) && max(cols) < m.nc() && (cols.nr() == 1 || cols.nc() == 1),
"\tassignable_matrix_expression set_colm(matrix& m, const matrix_exp& cols)"
<< "\n\tYou have specified invalid sub matrix dimensions"
<< "\n\tm.nr(): " << m.nr()
<< "\n\tm.nc(): " << m.nc()
<< "\n\tmin(cols): " << min(cols)
<< "\n\tmax(cols): " << max(cols)
<< "\n\tcols.nr(): " << cols.nr()
<< "\n\tcols.nc(): " << cols.nc()
);
return assignable_sub_range_matrix<T,NR,NC,mm,l,matrix_exp<matrix_range_exp>,matrix_exp<EXPc> >(m,range(0,m.nr()-1),cols);
}
// ----------------------------------------------------------------------------------------
template <typename T, long NR, long NC, typename mm, typename l>
......
......@@ -218,13 +218,16 @@ namespace dlib
);
/*!
requires
- start <= end
- inc > 0
ensures
- returns a matrix R such that:
- R::type == long
- R.nr() == (end - start)/inc + 1
- R.nr() == abs(end - start)/inc + 1
- R.nc() == 1
- if (start <= end) then
- R(i) == start + i*inc
- else
- R(i) == start - i*inc
!*/
template <long start, long end>
......@@ -238,13 +241,16 @@ namespace dlib
);
/*!
requires
- start <= end
- inc > 0
ensures
- returns a matrix R such that:
- R::type == long
- R.nr() == (end - start)/inc + 1
- R.nr() == abs(end - start)/inc + 1
- R.nc() == 1
- if (start <= end) then
- R(i) == start + i*inc
- else
- R(i) == start - i*inc
!*/
const matrix_exp range (
......@@ -334,6 +340,27 @@ namespace dlib
R(i) == m(row,i)
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp rowm (
const matrix_exp& m,
const matrix_exp& rows
);
/*!
requires
- rows contains elements of type long
- 0 <= min(rows) && max(rows) < m.nr()
- rows.nr() == 1 || rows.nc() == 1
(i.e. rows must be a vector)
ensures
- returns a matrix R such that:
- R::type == the same type that was in m
- R.nr() == rows.size()
- R.nc() == m.nc()
- for all valid r and c:
R(r,c) == m(rows(r),c)
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp colm (
......@@ -351,6 +378,27 @@ namespace dlib
R(i) == m(i,col)
!*/
// ----------------------------------------------------------------------------------------
const matrix_exp colm (
const matrix_exp& m,
const matrix_exp& cols
);
/*!
requires
- cols contains elements of type long
- 0 <= min(cols) && max(cols) < m.nc()
- cols.nr() == 1 || cols.nc() == 1
(i.e. cols must be a vector)
ensures
- returns a matrix R such that:
- R::type == the same type that was in m
- R.nr() == m.nr()
- R.nc() == cols.size()
- for all valid r and c:
R(r,c) == m(r,cols(c))
!*/
// ----------------------------------------------------------------------------------------
assignable_matrix_expression set_subm (
......@@ -448,6 +496,30 @@ namespace dlib
- rowm(m,row) == uniform_matrix<matrix::type>(1,nc,scalar_value).
!*/
// ----------------------------------------------------------------------------------------
assignable_matrix_expression set_rowm (
matrix& m,
const matrix_exp& rows
);
/*!
requires
- rows contains elements of type long
- 0 <= min(rows) && max(rows) < m.nr()
- rows.nr() == 1 || rows.nc() == 1
(i.e. rows must be a vector)
ensures
- statements of the following form:
- set_rowm(m,rows) = some_matrix;
result in it being the case that:
- rowm(m,rows) == some_matrix.
- statements of the following form:
- set_rowm(m,rows) = scalar_value;
result in it being the case that:
- rowm(m,rows) == uniform_matrix<matrix::type>(nr,nc,scalar_value).
!*/
// ----------------------------------------------------------------------------------------
assignable_matrix_expression set_colm (
......@@ -469,6 +541,30 @@ namespace dlib
- colm(m,col) == uniform_matrix<matrix::type>(nr,1,scalar_value).
!*/
// ----------------------------------------------------------------------------------------
assignable_matrix_expression set_colm (
matrix& m,
const matrix_exp& cols
);
/*!
requires
- cols contains elements of type long
- 0 <= min(cols) && max(cols) < m.nc()
- cols.nr() == 1 || cols.nc() == 1
(i.e. cols must be a vector)
ensures
- statements of the following form:
- set_colm(m,cols) = some_matrix;
result in it being the case that:
- colm(m,cols) == some_matrix.
- statements of the following form:
- set_colm(m,cols) = scalar_value;
result in it being the case that:
- colm(m,cols) == uniform_matrix<matrix::type>(nr,nc,scalar_value).
!*/
// ----------------------------------------------------------------------------------------
template <
......
......@@ -829,6 +829,76 @@ namespace
DLIB_CASSERT((trans(range<1,3,5>()).nr() == 1),"");
}
{
DLIB_CASSERT(range(5,0).nr() == 6,"");
DLIB_CASSERT(range(5,1).nr() == 5,"");
DLIB_CASSERT(range(5,0).nc() == 1,"");
DLIB_CASSERT(range(5,1).nc() == 1,"");
DLIB_CASSERT(trans(range(5,0)).nc() == 6,"");
DLIB_CASSERT(trans(range(5,1)).nc() == 5,"");
DLIB_CASSERT(trans(range(5,0)).nr() == 1,"");
DLIB_CASSERT(trans(range(5,1)).nr() == 1,"");
DLIB_CASSERT(range(5,2,0).nr() == 3,"");
DLIB_CASSERT(range(5,2,1).nr() == 3,"");
DLIB_CASSERT(range(5,2,0).nc() == 1,"");
DLIB_CASSERT(range(5,2,1).nc() == 1,"");
DLIB_CASSERT(trans(range(5,2,0)).nc() == 3,"");
DLIB_CASSERT(trans(range(5,2,1)).nc() == 3,"");
DLIB_CASSERT(trans(range(5,2,0)).nr() == 1,"");
DLIB_CASSERT(trans(range(5,2,1)).nr() == 1,"");
DLIB_CASSERT(range(6,3,0).nr() == 3,"");
DLIB_CASSERT(range(5,3,1).nr() == 2,"");
DLIB_CASSERT(range(5,3,0).nc() == 1,"");
DLIB_CASSERT(range(5,3,1).nc() == 1,"");
DLIB_CASSERT(trans(range(6,3,0)).nc() == 3,"");
DLIB_CASSERT(trans(range(5,3,1)).nc() == 2,"");
DLIB_CASSERT(trans(range(5,3,0)).nr() == 1,"");
DLIB_CASSERT(trans(range(5,3,1)).nr() == 1,"");
DLIB_CASSERT(range(5,9,1).nr() == 1,"");
DLIB_CASSERT(range(5,9,1).nc() == 1,"");
DLIB_CASSERT(range(0,0).nr() == 1,"");
DLIB_CASSERT(range(0,0).nc() == 1,"");
DLIB_CASSERT(range(1,1)(0) == 1,"");
DLIB_CASSERT(range(5,0)(0) == 5 && range(5,0)(1) == 4 && range(5,0)(5) == 0,"");
DLIB_CASSERT(range(5,2,1)(0) == 5 && range(5,2,1)(1) == 3 && range(5,2,1)(2) == 1,"");
DLIB_CASSERT((range<5,0>()(0) == 5 && range<5,0>()(1) == 4 && range<5,0>()(5) == 0),"");
DLIB_CASSERT((range<5,2,1>()(0) == 5 && range<5,2,1>()(1) == 3 && range<5,2,1>()(2) == 1),"");
DLIB_CASSERT((range<5,0>().nr() == 6),"");
DLIB_CASSERT((range<5,1>().nr() == 5),"");
DLIB_CASSERT((range<5,0>().nc() == 1),"");
DLIB_CASSERT((range<5,1>().nc() == 1),"");
DLIB_CASSERT((trans(range<5,0>()).nc() == 6),"");
DLIB_CASSERT((trans(range<5,1>()).nc() == 5),"");
DLIB_CASSERT((trans(range<5,0>()).nr() == 1),"");
DLIB_CASSERT((trans(range<5,1>()).nr() == 1),"");
DLIB_CASSERT((range<5,2,0>().nr() == 3),"");
DLIB_CASSERT((range<5,2,1>().nr() == 3),"");
DLIB_CASSERT((range<5,2,0>().nc() == 1),"");
DLIB_CASSERT((range<5,2,1>().nc() == 1),"");
DLIB_CASSERT((trans(range<5,2,0>()).nc() == 3),"");
DLIB_CASSERT((trans(range<5,2,1>()).nc() == 3),"");
DLIB_CASSERT((trans(range<5,2,0>()).nr() == 1),"");
DLIB_CASSERT((trans(range<5,2,1>()).nr() == 1),"");
DLIB_CASSERT((range<6,3,0>().nr() == 3),"");
DLIB_CASSERT((range<5,3,1>().nr() == 2),"");
DLIB_CASSERT((range<5,3,0>().nc() == 1),"");
DLIB_CASSERT((range<5,3,1>().nc() == 1),"");
DLIB_CASSERT((trans(range<6,3,0>()).nc() == 3),"");
DLIB_CASSERT((trans(range<5,3,1>()).nc() == 2),"");
DLIB_CASSERT((trans(range<5,3,0>()).nr() == 1),"");
DLIB_CASSERT((trans(range<5,3,1>()).nr() == 1),"");
}
{
matrix<double> m(4,3);
for (long r = 0; r < m.nr(); ++r)
......
......@@ -943,6 +943,36 @@ namespace
}
{
matrix<double> m(4,4), m2;
m = 1,2,3,4,
1,2,3,4,
4,6,8,10,
4,6,8,10;
m2 = m;
DLIB_CASSERT(colm(m,range(0,3)) == m,"");
DLIB_CASSERT(rowm(m,range(0,3)) == m,"");
DLIB_CASSERT(colm(m,range(0,0)) == colm(m,0),"");
DLIB_CASSERT(rowm(m,range(0,0)) == rowm(m,0),"");
DLIB_CASSERT(colm(m,range(1,1)) == colm(m,1),"");
DLIB_CASSERT(rowm(m,range(1,1)) == rowm(m,1),"");
DLIB_CASSERT(colm(m,range(2,2)) == colm(m,2),"");
DLIB_CASSERT(rowm(m,range(2,2)) == rowm(m,2),"");
DLIB_CASSERT(colm(m,range(1,2)) == subm(m,0,1,4,2),"");
DLIB_CASSERT(rowm(m,range(1,2)) == subm(m,1,0,2,4),"");
set_colm(m,range(1,2)) = 9;
set_subm(m2,0,1,4,2) = 9;
DLIB_CASSERT(m == m2,"");
set_colm(m,range(1,2)) = 11;
set_subm(m2,0,1,4,2) = 11;
DLIB_CASSERT(m == m2,"");
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment