Commit 0f89c0f6 authored by Davis King's avatar Davis King

- Changed the range() stuff so that it works with ranges of the form

    range(1,5) as well as range(5,1)
  - Added overloads of colm(), rowm(), set_colm(), and set_rowm() that
    can work with ranges.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402722
parent 53ea098d
...@@ -1833,8 +1833,11 @@ convergence: ...@@ -1833,8 +1833,11 @@ convergence:
) )
{ {
start = start_; start = start_;
inc = 1; if (start_ <= end_)
nr_ = end_ - start_ + 1; inc = 1;
else
inc = -1;
nr_ = std::abs(end_ - start_) + 1;
} }
matrix_range_exp ( matrix_range_exp (
long start_, long start_,
...@@ -1843,8 +1846,11 @@ convergence: ...@@ -1843,8 +1846,11 @@ convergence:
) )
{ {
start = start_; start = start_;
inc = inc_; nr_ = std::abs(end_ - start_)/inc_ + 1;
nr_ = (end_ - start_)/inc_ + 1; if (start_ <= end_)
inc = inc_;
else
inc = -inc_;
} }
long operator() ( long operator() (
...@@ -1879,14 +1885,15 @@ convergence: ...@@ -1879,14 +1885,15 @@ convergence:
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <long start, long inc, long end> template <long start, long inc_, long end>
class matrix_range_static_exp class matrix_range_static_exp
{ {
public: public:
typedef long type; typedef long type;
typedef matrix_range_static_exp ref_type; typedef matrix_range_static_exp ref_type;
typedef memory_manager<char>::kernel_1a mem_manager_type; typedef memory_manager<char>::kernel_1a mem_manager_type;
const static long NR = (end - start)/inc + 1; const static long inc = (start <= end)?inc_:-inc_;
const static long NR = tabs<(end - start)>::value/inc_ + 1;
const static long NC = 1; const static long NC = 1;
long operator() ( long operator() (
...@@ -1920,7 +1927,7 @@ convergence: ...@@ -1920,7 +1927,7 @@ convergence:
const matrix_exp<matrix_range_static_exp<start,inc,end> > range ( const matrix_exp<matrix_range_static_exp<start,inc,end> > range (
) )
{ {
COMPILE_TIME_ASSERT(start <= end); COMPILE_TIME_ASSERT(inc > 0);
return matrix_exp<matrix_range_static_exp<start,inc,end> >(matrix_range_static_exp<start,inc,end>()); return matrix_exp<matrix_range_static_exp<start,inc,end> >(matrix_range_static_exp<start,inc,end>());
} }
...@@ -1928,7 +1935,6 @@ convergence: ...@@ -1928,7 +1935,6 @@ convergence:
const matrix_exp<matrix_range_static_exp<start,1,end> > range ( const matrix_exp<matrix_range_static_exp<start,1,end> > range (
) )
{ {
COMPILE_TIME_ASSERT(start <= end);
return matrix_exp<matrix_range_static_exp<start,1,end> >(matrix_range_static_exp<start,1,end>()); return matrix_exp<matrix_range_static_exp<start,1,end> >(matrix_range_static_exp<start,1,end>());
} }
...@@ -1937,13 +1943,6 @@ convergence: ...@@ -1937,13 +1943,6 @@ convergence:
long end long end
) )
{ {
DLIB_ASSERT(start <= end,
"\tconst matrix_exp range(start, end)"
<< "\n\tstart can't be bigger than end"
<< "\n\tstart: " << start
<< "\n\tend: " << end
);
return matrix_exp<matrix_range_exp>(matrix_range_exp(start,end)); return matrix_exp<matrix_range_exp>(matrix_range_exp(start,end));
} }
...@@ -1953,10 +1952,11 @@ convergence: ...@@ -1953,10 +1952,11 @@ convergence:
long end long end
) )
{ {
DLIB_ASSERT(start <= end, DLIB_ASSERT(inc > 0,
"\tconst matrix_exp range(start, inc, end)" "\tconst matrix_exp range(start, inc, end)"
<< "\n\tstart can't be bigger than end" << "\n\tstart can't be bigger than end"
<< "\n\tstart: " << start << "\n\tstart: " << start
<< "\n\tinc: " << inc
<< "\n\tend: " << end << "\n\tend: " << end
); );
...@@ -2250,6 +2250,56 @@ convergence: ...@@ -2250,6 +2250,56 @@ convergence:
return matrix_exp<exp>(exp(m,row)); return matrix_exp<exp>(exp(m,row));
} }
// ----------------------------------------------------------------------------------------
struct op_rowm_range
{
template <typename EXP1, typename EXP2>
struct op : has_destructive_aliasing
{
typedef typename EXP1::type type;
typedef typename EXP1::mem_manager_type mem_manager_type;
const static long NR = EXP2::NC*EXP2::NR;
const static long NC = EXP1::NC;
template <typename M1, typename M2>
static type apply ( const M1& m1, const M2& rows , long r, long c)
{ return m1(rows(r),c); }
template <typename M1, typename M2>
static long nr (const M1& m1, const M2& rows ) { return rows.size(); }
template <typename M1, typename M2>
static long nc (const M1& m1, const M2& ) { return m1.nc(); }
};
};
template <
typename EXP1,
typename EXP2
>
const matrix_exp<matrix_binary_exp<matrix_exp<EXP1>,matrix_exp<EXP2>,op_rowm_range> > rowm (
const matrix_exp<EXP1>& m,
const matrix_exp<EXP2>& rows
)
{
// the rows matrix must contain elements of type long
COMPILE_TIME_ASSERT((is_same_type<typename EXP2::type,long>::value == true));
DLIB_ASSERT(0 <= min(rows) && max(rows) < m.nr() && (rows.nr() == 1 || rows.nc() == 1),
"\tconst matrix_exp rowm(const matrix_exp& m, const matrix_exp& rows)"
<< "\n\tYou have given invalid arguments to this function"
<< "\n\tm.nr(): " << m.nr()
<< "\n\tm.nc(): " << m.nc()
<< "\n\tmin(rows): " << min(rows)
<< "\n\tmax(rows): " << max(rows)
<< "\n\trows.nr(): " << rows.nr()
<< "\n\trows.nc(): " << rows.nc()
);
typedef matrix_binary_exp<matrix_exp<EXP1>,matrix_exp<EXP2>,op_rowm_range> exp;
return matrix_exp<exp>(exp(m,rows));
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
struct op_colm struct op_colm
...@@ -2292,6 +2342,56 @@ convergence: ...@@ -2292,6 +2342,56 @@ convergence:
return matrix_exp<exp>(exp(m,col)); return matrix_exp<exp>(exp(m,col));
} }
// ----------------------------------------------------------------------------------------
struct op_colm_range
{
template <typename EXP1, typename EXP2>
struct op : has_destructive_aliasing
{
typedef typename EXP1::type type;
typedef typename EXP1::mem_manager_type mem_manager_type;
const static long NR = EXP1::NR;
const static long NC = EXP2::NC*EXP2::NR;
template <typename M1, typename M2>
static type apply ( const M1& m1, const M2& cols , long r, long c)
{ return m1(r,cols(c)); }
template <typename M1, typename M2>
static long nr (const M1& m1, const M2& cols ) { return m1.nr(); }
template <typename M1, typename M2>
static long nc (const M1& m1, const M2& cols ) { return cols.size(); }
};
};
template <
typename EXP1,
typename EXP2
>
const matrix_exp<matrix_binary_exp<matrix_exp<EXP1>,matrix_exp<EXP2>,op_colm_range> > colm (
const matrix_exp<EXP1>& m,
const matrix_exp<EXP2>& cols
)
{
// the cols matrix must contain elements of type long
COMPILE_TIME_ASSERT((is_same_type<typename EXP2::type,long>::value == true));
DLIB_ASSERT(0 <= min(cols) && max(cols) < m.nc() && (cols.nr() == 1 || cols.nc() == 1),
"\tconst matrix_exp colm(const matrix_exp& m, const matrix_exp& cols)"
<< "\n\tYou have given invalid arguments to this function"
<< "\n\tm.nr(): " << m.nr()
<< "\n\tm.nc(): " << m.nc()
<< "\n\tmin(cols): " << min(cols)
<< "\n\tmax(cols): " << max(cols)
<< "\n\tcols.nr(): " << cols.nr()
<< "\n\tcols.nc(): " << cols.nc()
);
typedef matrix_binary_exp<matrix_exp<EXP1>,matrix_exp<EXP2>,op_colm_range> exp;
return matrix_exp<exp>(exp(m,cols));
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
...@@ -2510,6 +2610,50 @@ convergence: ...@@ -2510,6 +2610,50 @@ convergence:
return assignable_sub_range_matrix<T,NR,NC,mm,l,matrix_exp<EXPr>,matrix_exp<EXPc> >(m,rows,cols); return assignable_sub_range_matrix<T,NR,NC,mm,l,matrix_exp<EXPr>,matrix_exp<EXPc> >(m,rows,cols);
} }
// ----------------------------------------------------------------------------------------
template <typename T, long NR, long NC, typename mm, typename l, typename EXPr>
assignable_sub_range_matrix<T,NR,NC,mm,l,matrix_exp<EXPr>,matrix_exp<matrix_range_exp> > set_rowm (
matrix<T,NR,NC,mm,l>& m,
const matrix_exp<EXPr>& rows
)
{
DLIB_ASSERT(0 <= min(rows) && max(rows) < m.nr() && (rows.nr() == 1 || rows.nc() == 1),
"\tassignable_matrix_expression set_rowm(matrix& m, const matrix_exp& rows)"
<< "\n\tYou have specified invalid sub matrix dimensions"
<< "\n\tm.nr(): " << m.nr()
<< "\n\tm.nc(): " << m.nc()
<< "\n\tmin(rows): " << min(rows)
<< "\n\tmax(rows): " << max(rows)
<< "\n\trows.nr(): " << rows.nr()
<< "\n\trows.nc(): " << rows.nc()
);
return assignable_sub_range_matrix<T,NR,NC,mm,l,matrix_exp<EXPr>,matrix_exp<matrix_range_exp> >(m,rows,range(0,m.nc()-1));
}
// ----------------------------------------------------------------------------------------
template <typename T, long NR, long NC, typename mm, typename l, typename EXPc>
assignable_sub_range_matrix<T,NR,NC,mm,l,matrix_exp<matrix_range_exp>,matrix_exp<EXPc> > set_colm (
matrix<T,NR,NC,mm,l>& m,
const matrix_exp<EXPc>& cols
)
{
DLIB_ASSERT(0 <= min(cols) && max(cols) < m.nc() && (cols.nr() == 1 || cols.nc() == 1),
"\tassignable_matrix_expression set_colm(matrix& m, const matrix_exp& cols)"
<< "\n\tYou have specified invalid sub matrix dimensions"
<< "\n\tm.nr(): " << m.nr()
<< "\n\tm.nc(): " << m.nc()
<< "\n\tmin(cols): " << min(cols)
<< "\n\tmax(cols): " << max(cols)
<< "\n\tcols.nr(): " << cols.nr()
<< "\n\tcols.nc(): " << cols.nc()
);
return assignable_sub_range_matrix<T,NR,NC,mm,l,matrix_exp<matrix_range_exp>,matrix_exp<EXPc> >(m,range(0,m.nr()-1),cols);
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename T, long NR, long NC, typename mm, typename l> template <typename T, long NR, long NC, typename mm, typename l>
......
...@@ -218,13 +218,16 @@ namespace dlib ...@@ -218,13 +218,16 @@ namespace dlib
); );
/*! /*!
requires requires
- start <= end - inc > 0
ensures ensures
- returns a matrix R such that: - returns a matrix R such that:
- R::type == long - R::type == long
- R.nr() == (end - start)/inc + 1 - R.nr() == abs(end - start)/inc + 1
- R.nc() == 1 - R.nc() == 1
- R(i) == start + i*inc - if (start <= end) then
- R(i) == start + i*inc
- else
- R(i) == start - i*inc
!*/ !*/
template <long start, long end> template <long start, long end>
...@@ -238,13 +241,16 @@ namespace dlib ...@@ -238,13 +241,16 @@ namespace dlib
); );
/*! /*!
requires requires
- start <= end - inc > 0
ensures ensures
- returns a matrix R such that: - returns a matrix R such that:
- R::type == long - R::type == long
- R.nr() == (end - start)/inc + 1 - R.nr() == abs(end - start)/inc + 1
- R.nc() == 1 - R.nc() == 1
- R(i) == start + i*inc - if (start <= end) then
- R(i) == start + i*inc
- else
- R(i) == start - i*inc
!*/ !*/
const matrix_exp range ( const matrix_exp range (
...@@ -334,6 +340,27 @@ namespace dlib ...@@ -334,6 +340,27 @@ namespace dlib
R(i) == m(row,i) R(i) == m(row,i)
!*/ !*/
// ----------------------------------------------------------------------------------------
const matrix_exp rowm (
const matrix_exp& m,
const matrix_exp& rows
);
/*!
requires
- rows contains elements of type long
- 0 <= min(rows) && max(rows) < m.nr()
- rows.nr() == 1 || rows.nc() == 1
(i.e. rows must be a vector)
ensures
- returns a matrix R such that:
- R::type == the same type that was in m
- R.nr() == rows.size()
- R.nc() == m.nc()
- for all valid r and c:
R(r,c) == m(rows(r),c)
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
const matrix_exp colm ( const matrix_exp colm (
...@@ -351,6 +378,27 @@ namespace dlib ...@@ -351,6 +378,27 @@ namespace dlib
R(i) == m(i,col) R(i) == m(i,col)
!*/ !*/
// ----------------------------------------------------------------------------------------
const matrix_exp colm (
const matrix_exp& m,
const matrix_exp& cols
);
/*!
requires
- cols contains elements of type long
- 0 <= min(cols) && max(cols) < m.nc()
- cols.nr() == 1 || cols.nc() == 1
(i.e. cols must be a vector)
ensures
- returns a matrix R such that:
- R::type == the same type that was in m
- R.nr() == m.nr()
- R.nc() == cols.size()
- for all valid r and c:
R(r,c) == m(r,cols(c))
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
assignable_matrix_expression set_subm ( assignable_matrix_expression set_subm (
...@@ -448,6 +496,30 @@ namespace dlib ...@@ -448,6 +496,30 @@ namespace dlib
- rowm(m,row) == uniform_matrix<matrix::type>(1,nc,scalar_value). - rowm(m,row) == uniform_matrix<matrix::type>(1,nc,scalar_value).
!*/ !*/
// ----------------------------------------------------------------------------------------
assignable_matrix_expression set_rowm (
matrix& m,
const matrix_exp& rows
);
/*!
requires
- rows contains elements of type long
- 0 <= min(rows) && max(rows) < m.nr()
- rows.nr() == 1 || rows.nc() == 1
(i.e. rows must be a vector)
ensures
- statements of the following form:
- set_rowm(m,rows) = some_matrix;
result in it being the case that:
- rowm(m,rows) == some_matrix.
- statements of the following form:
- set_rowm(m,rows) = scalar_value;
result in it being the case that:
- rowm(m,rows) == uniform_matrix<matrix::type>(nr,nc,scalar_value).
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
assignable_matrix_expression set_colm ( assignable_matrix_expression set_colm (
...@@ -469,6 +541,30 @@ namespace dlib ...@@ -469,6 +541,30 @@ namespace dlib
- colm(m,col) == uniform_matrix<matrix::type>(nr,1,scalar_value). - colm(m,col) == uniform_matrix<matrix::type>(nr,1,scalar_value).
!*/ !*/
// ----------------------------------------------------------------------------------------
assignable_matrix_expression set_colm (
matrix& m,
const matrix_exp& cols
);
/*!
requires
- cols contains elements of type long
- 0 <= min(cols) && max(cols) < m.nc()
- cols.nr() == 1 || cols.nc() == 1
(i.e. cols must be a vector)
ensures
- statements of the following form:
- set_colm(m,cols) = some_matrix;
result in it being the case that:
- colm(m,cols) == some_matrix.
- statements of the following form:
- set_colm(m,cols) = scalar_value;
result in it being the case that:
- colm(m,cols) == uniform_matrix<matrix::type>(nr,nc,scalar_value).
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <
......
...@@ -829,6 +829,76 @@ namespace ...@@ -829,6 +829,76 @@ namespace
DLIB_CASSERT((trans(range<1,3,5>()).nr() == 1),""); DLIB_CASSERT((trans(range<1,3,5>()).nr() == 1),"");
} }
{
DLIB_CASSERT(range(5,0).nr() == 6,"");
DLIB_CASSERT(range(5,1).nr() == 5,"");
DLIB_CASSERT(range(5,0).nc() == 1,"");
DLIB_CASSERT(range(5,1).nc() == 1,"");
DLIB_CASSERT(trans(range(5,0)).nc() == 6,"");
DLIB_CASSERT(trans(range(5,1)).nc() == 5,"");
DLIB_CASSERT(trans(range(5,0)).nr() == 1,"");
DLIB_CASSERT(trans(range(5,1)).nr() == 1,"");
DLIB_CASSERT(range(5,2,0).nr() == 3,"");
DLIB_CASSERT(range(5,2,1).nr() == 3,"");
DLIB_CASSERT(range(5,2,0).nc() == 1,"");
DLIB_CASSERT(range(5,2,1).nc() == 1,"");
DLIB_CASSERT(trans(range(5,2,0)).nc() == 3,"");
DLIB_CASSERT(trans(range(5,2,1)).nc() == 3,"");
DLIB_CASSERT(trans(range(5,2,0)).nr() == 1,"");
DLIB_CASSERT(trans(range(5,2,1)).nr() == 1,"");
DLIB_CASSERT(range(6,3,0).nr() == 3,"");
DLIB_CASSERT(range(5,3,1).nr() == 2,"");
DLIB_CASSERT(range(5,3,0).nc() == 1,"");
DLIB_CASSERT(range(5,3,1).nc() == 1,"");
DLIB_CASSERT(trans(range(6,3,0)).nc() == 3,"");
DLIB_CASSERT(trans(range(5,3,1)).nc() == 2,"");
DLIB_CASSERT(trans(range(5,3,0)).nr() == 1,"");
DLIB_CASSERT(trans(range(5,3,1)).nr() == 1,"");
DLIB_CASSERT(range(5,9,1).nr() == 1,"");
DLIB_CASSERT(range(5,9,1).nc() == 1,"");
DLIB_CASSERT(range(0,0).nr() == 1,"");
DLIB_CASSERT(range(0,0).nc() == 1,"");
DLIB_CASSERT(range(1,1)(0) == 1,"");
DLIB_CASSERT(range(5,0)(0) == 5 && range(5,0)(1) == 4 && range(5,0)(5) == 0,"");
DLIB_CASSERT(range(5,2,1)(0) == 5 && range(5,2,1)(1) == 3 && range(5,2,1)(2) == 1,"");
DLIB_CASSERT((range<5,0>()(0) == 5 && range<5,0>()(1) == 4 && range<5,0>()(5) == 0),"");
DLIB_CASSERT((range<5,2,1>()(0) == 5 && range<5,2,1>()(1) == 3 && range<5,2,1>()(2) == 1),"");
DLIB_CASSERT((range<5,0>().nr() == 6),"");
DLIB_CASSERT((range<5,1>().nr() == 5),"");
DLIB_CASSERT((range<5,0>().nc() == 1),"");
DLIB_CASSERT((range<5,1>().nc() == 1),"");
DLIB_CASSERT((trans(range<5,0>()).nc() == 6),"");
DLIB_CASSERT((trans(range<5,1>()).nc() == 5),"");
DLIB_CASSERT((trans(range<5,0>()).nr() == 1),"");
DLIB_CASSERT((trans(range<5,1>()).nr() == 1),"");
DLIB_CASSERT((range<5,2,0>().nr() == 3),"");
DLIB_CASSERT((range<5,2,1>().nr() == 3),"");
DLIB_CASSERT((range<5,2,0>().nc() == 1),"");
DLIB_CASSERT((range<5,2,1>().nc() == 1),"");
DLIB_CASSERT((trans(range<5,2,0>()).nc() == 3),"");
DLIB_CASSERT((trans(range<5,2,1>()).nc() == 3),"");
DLIB_CASSERT((trans(range<5,2,0>()).nr() == 1),"");
DLIB_CASSERT((trans(range<5,2,1>()).nr() == 1),"");
DLIB_CASSERT((range<6,3,0>().nr() == 3),"");
DLIB_CASSERT((range<5,3,1>().nr() == 2),"");
DLIB_CASSERT((range<5,3,0>().nc() == 1),"");
DLIB_CASSERT((range<5,3,1>().nc() == 1),"");
DLIB_CASSERT((trans(range<6,3,0>()).nc() == 3),"");
DLIB_CASSERT((trans(range<5,3,1>()).nc() == 2),"");
DLIB_CASSERT((trans(range<5,3,0>()).nr() == 1),"");
DLIB_CASSERT((trans(range<5,3,1>()).nr() == 1),"");
}
{ {
matrix<double> m(4,3); matrix<double> m(4,3);
for (long r = 0; r < m.nr(); ++r) for (long r = 0; r < m.nr(); ++r)
......
...@@ -943,6 +943,36 @@ namespace ...@@ -943,6 +943,36 @@ namespace
} }
{
matrix<double> m(4,4), m2;
m = 1,2,3,4,
1,2,3,4,
4,6,8,10,
4,6,8,10;
m2 = m;
DLIB_CASSERT(colm(m,range(0,3)) == m,"");
DLIB_CASSERT(rowm(m,range(0,3)) == m,"");
DLIB_CASSERT(colm(m,range(0,0)) == colm(m,0),"");
DLIB_CASSERT(rowm(m,range(0,0)) == rowm(m,0),"");
DLIB_CASSERT(colm(m,range(1,1)) == colm(m,1),"");
DLIB_CASSERT(rowm(m,range(1,1)) == rowm(m,1),"");
DLIB_CASSERT(colm(m,range(2,2)) == colm(m,2),"");
DLIB_CASSERT(rowm(m,range(2,2)) == rowm(m,2),"");
DLIB_CASSERT(colm(m,range(1,2)) == subm(m,0,1,4,2),"");
DLIB_CASSERT(rowm(m,range(1,2)) == subm(m,1,0,2,4),"");
set_colm(m,range(1,2)) = 9;
set_subm(m2,0,1,4,2) = 9;
DLIB_CASSERT(m == m2,"");
set_colm(m,range(1,2)) = 11;
set_subm(m2,0,1,4,2) = 11;
DLIB_CASSERT(m == m2,"");
}
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment