Commit f36501d7 authored by Davis King's avatar Davis King

Updated the equal() function so that it can compare complex matrices.

I also changed a matrix test case to be more robust.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%403075
parent eeb0b806
...@@ -16,10 +16,32 @@ ...@@ -16,10 +16,32 @@
#include "matrix_expressions.h" #include "matrix_expressions.h"
namespace dlib namespace dlib
{ {
// ----------------------------------------------------------------------------------------
/*!A is_complex
This is a template that can be used to determine if a type is a specialization
of the std::complex template class.
For example:
is_complex<float>::value == false
is_complex<std::complex<float> >::value == true
!*/
template <typename T>
struct is_complex { static const bool value = false; };
template <typename T>
struct is_complex<std::complex<T> > { static const bool value = true; };
template <typename T>
struct is_complex<std::complex<T>& > { static const bool value = true; };
template <typename T>
struct is_complex<const std::complex<T>& > { static const bool value = true; };
template <typename T>
struct is_complex<const std::complex<T> > { static const bool value = true; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename EXP> template <typename EXP>
...@@ -1657,7 +1679,7 @@ namespace dlib ...@@ -1657,7 +1679,7 @@ namespace dlib
typename EXP1, typename EXP1,
typename EXP2 typename EXP2
> >
bool equal ( typename disable_if<is_complex<typename EXP1::type>,bool>::type equal (
const matrix_exp<EXP1>& a, const matrix_exp<EXP1>& a,
const matrix_exp<EXP2>& b, const matrix_exp<EXP2>& b,
const typename EXP1::type eps = 100*std::numeric_limits<typename EXP1::type>::epsilon() const typename EXP1::type eps = 100*std::numeric_limits<typename EXP1::type>::epsilon()
...@@ -1680,6 +1702,34 @@ namespace dlib ...@@ -1680,6 +1702,34 @@ namespace dlib
return true; return true;
} }
template <
typename EXP1,
typename EXP2
>
typename enable_if<is_complex<typename EXP1::type>,bool>::type equal (
const matrix_exp<EXP1>& a,
const matrix_exp<EXP2>& b,
const typename EXP1::type::value_type eps = 100*std::numeric_limits<typename EXP1::type::value_type>::epsilon()
)
{
// check if the dimensions don't match
if (a.nr() != b.nr() || a.nc() != b.nc())
return false;
for (long r = 0; r < a.nr(); ++r)
{
for (long c = 0; c < a.nc(); ++c)
{
if (std::abs(real(a(r,c)-b(r,c))) > eps ||
std::abs(imag(a(r,c)-b(r,c))) > eps)
return false;
}
}
// no non-equal points found so we return true
return true;
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
struct op_scale_columns struct op_scale_columns
......
...@@ -501,6 +501,7 @@ namespace dlib ...@@ -501,6 +501,7 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
// if matrix_exp contains non-complex types (e.g. float, double)
bool equal ( bool equal (
const matrix_exp& a, const matrix_exp& a,
const matrix_exp& b, const matrix_exp& b,
...@@ -516,6 +517,25 @@ namespace dlib ...@@ -516,6 +517,25 @@ namespace dlib
- returns true - returns true
!*/ !*/
// ----------------------------------------------------------------------------------------
// if matrix_exp contains std::complex types
bool equal (
const matrix_exp& a,
const matrix_exp& b,
const matrix_exp::type::value_type epsilon = 100*std::numeric_limits<matrix_exp::type::value_type>::epsilon()
);
/*!
ensures
- if (a and b don't have the same dimensions) then
- returns false
- else if (there exists an r and c such that abs(real(a(r,c)-b(r,c))) > epsilon
or abs(imag(a(r,c)-b(r,c))) > epsilon) then
- returns false
- else
- returns true
!*/
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
const matrix_exp pointwise_multiply ( const matrix_exp pointwise_multiply (
......
...@@ -857,7 +857,7 @@ namespace ...@@ -857,7 +857,7 @@ namespace
m = val1; m = val1;
m2 = val2; m2 = val2;
DLIB_TEST(reciprocal(m) == m2); DLIB_TEST(equal(reciprocal(m) , m2));
} }
{ {
matrix<complex<float> > m(2,2), m2(2,2); matrix<complex<float> > m(2,2), m2(2,2);
...@@ -865,7 +865,7 @@ namespace ...@@ -865,7 +865,7 @@ namespace
m = val1; m = val1;
m2 = val2; m2 = val2;
DLIB_TEST(reciprocal(m) == m2); DLIB_TEST(equal(reciprocal(m) , m2));
} }
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment