Commit fa543c05 authored by Davis King's avatar Davis King

Added an overload of clamp() that lets you use matrix valued lower/upper

bounds.
parent 12138269
...@@ -2741,6 +2741,73 @@ namespace dlib ...@@ -2741,6 +2741,73 @@ namespace dlib
return matrix_op<op>(op(m.ref(),lower, upper)); return matrix_op<op>(op(m.ref(),lower, upper));
} }
// ----------------------------------------------------------------------------------------
template <typename M1, typename M2, typename M3>
struct op_clamp_m : basic_op_mmm<M1,M2,M3>
{
op_clamp_m( const M1& m1_, const M2& m2_, const M3& m3_) :
basic_op_mmm<M1,M2,M3>(m1_,m2_,m3_){}
typedef typename M1::type type;
typedef const typename M1::type const_ret_type;
const static long cost = M1::cost + M2::cost + M3::cost + 2;
const_ret_type apply (long r, long c) const
{
const type val = this->m1(r,c);
const type lower = this->m2(r,c);
const type upper = this->m3(r,c);
if (val <= upper)
{
if (lower <= val)
return val;
else
return lower;
}
else
{
return upper;
}
}
};
template <
typename EXP1,
typename EXP2,
typename EXP3
>
const matrix_op<op_clamp_m<EXP1,EXP2,EXP3> >
clamp (
const matrix_exp<EXP1>& m,
const matrix_exp<EXP2>& lower,
const matrix_exp<EXP3>& upper
)
{
COMPILE_TIME_ASSERT((is_same_type<typename EXP1::type,typename EXP2::type>::value == true));
COMPILE_TIME_ASSERT((is_same_type<typename EXP2::type,typename EXP3::type>::value == true));
COMPILE_TIME_ASSERT(EXP1::NR == EXP2::NR || EXP1::NR == 0 || EXP2::NR == 0);
COMPILE_TIME_ASSERT(EXP1::NC == EXP2::NC || EXP1::NR == 0 || EXP2::NC == 0);
COMPILE_TIME_ASSERT(EXP2::NR == EXP3::NR || EXP2::NR == 0 || EXP3::NR == 0);
COMPILE_TIME_ASSERT(EXP2::NC == EXP3::NC || EXP2::NC == 0 || EXP3::NC == 0);
DLIB_ASSERT(m.nr() == lower.nr() &&
m.nc() == lower.nc() &&
m.nr() == upper.nr() &&
m.nc() == upper.nc(),
"\tconst matrix_exp clamp(m,lower,upper)"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t m.nr(): " << m.nr()
<< "\n\t m.nc(): " << m.nc()
<< "\n\t lower.nr(): " << lower.nr()
<< "\n\t lower.nc(): " << lower.nc()
<< "\n\t upper.nr(): " << upper.nr()
<< "\n\t upper.nc(): " << upper.nc()
);
typedef op_clamp_m<EXP1,EXP2,EXP3> op;
return matrix_op<op>(op(m.ref(),lower.ref(),upper.ref()));
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename M> template <typename M>
......
...@@ -1695,6 +1695,33 @@ namespace dlib ...@@ -1695,6 +1695,33 @@ namespace dlib
- R(r,c) == m(r,c) - R(r,c) == m(r,c)
!*/ !*/
// ----------------------------------------------------------------------------------------
const matrix_exp clamp (
const matrix_exp& m,
const matrix_exp& lower,
const matrix_exp& upper
);
/*!
requires
- m.nr() == lower.nr()
- m.nc() == lower.nc()
- m.nr() == upper.nr()
- m.nc() == upper.nc()
- m, lower, and upper all contain the same type of elements.
ensures
- returns a matrix R such that:
- R::type == the same type that was in m
- R has the same dimensions as m
- for all valid r and c:
- if (m(r,c) > upper(r,c)) then
- R(r,c) == upper(r,c)
- else if (m(r,c) < lower(r,c)) then
- R(r,c) == lower(r,c)
- else
- R(r,c) == m(r,c)
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
const matrix_exp lowerbound ( const matrix_exp lowerbound (
......
...@@ -369,6 +369,33 @@ namespace ...@@ -369,6 +369,33 @@ namespace
DLIB_TEST(dm10 == m10); DLIB_TEST(dm10 == m10);
DLIB_TEST_MSG(sum(abs(sigmoid(dm10) -sigmoid(m10))) < 1e-10,sum(abs(sigmoid(dm10) -sigmoid(m10))) ); DLIB_TEST_MSG(sum(abs(sigmoid(dm10) -sigmoid(m10))) < 1e-10,sum(abs(sigmoid(dm10) -sigmoid(m10))) );
{
matrix<double,2,1> x, l, u, out;
x = 3,4;
l = 1,1;
u = 2,2.2;
out = 2, 2.2;
DLIB_TEST(equal(clamp(x, l, u) , out));
out = 3, 2.2;
DLIB_TEST(!equal(clamp(x, l, u) , out));
out = 2, 4.2;
DLIB_TEST(!equal(clamp(x, l, u) , out));
x = 1.5, 1.5;
out = x;
DLIB_TEST(equal(clamp(x, l, u) , out));
x = 0.5, 1.5;
out = 1, 1.5;
DLIB_TEST(equal(clamp(x, l, u) , out));
x = 1.5, 0.5;
out = 1.5, 1.0;
DLIB_TEST(equal(clamp(x, l, u) , out));
}
matrix<double, 7, 7,MM,column_major_layout> m7; matrix<double, 7, 7,MM,column_major_layout> m7;
matrix<double> dm7(7,7); matrix<double> dm7(7,7);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment