Commit 7206a12c authored by Davis King's avatar Davis King

- Added the tmax, tmin, and tabs templates

   - Changed the diag() function so that it is allowed to take
     non-square matrices.

--HG--
extra : convert_revision : svn%3Afdd8eb12-d10e-0410-9acb-85c331704f74/trunk%402697
parent cae2fb28
...@@ -568,6 +568,52 @@ namespace dlib ...@@ -568,6 +568,52 @@ namespace dlib
inline double put_in_range(const double& a, const double& b, const double& val) inline double put_in_range(const double& a, const double& b, const double& val)
{ return put_in_range<double>(a,b,val); } { return put_in_range<double>(a,b,val); }
// ----------------------------------------------------------------------------------------
/*! tabs
This is a template to compute the absolute value a number at compile time.
For example,
abs<-4>::value == 4
abs<4>::value == 4
!*/
template <long x, typename enabled=void>
struct tabs { const static long value = x; };
template <long x>
struct tabs<x,typename enable_if_c<(x < 0)>::type> { const static long value = -x; };
// ----------------------------------------------------------------------------------------
/*! tmax
This is a template to compute the max of two values at compile time
For example,
abs<4,7>::value == 7
!*/
template <long x, long y, typename enabled=void>
struct tmax { const static long value = x; };
template <long x, long y>
struct tmax<x,y,typename enable_if_c<(y > x)>::type> { const static long value = y; };
// ----------------------------------------------------------------------------------------
/*! tmin
This is a template to compute the min of two values at compile time
For example,
abs<4,7>::value == 4
!*/
template <long x, long y, typename enabled=void>
struct tmin { const static long value = x; };
template <long x, long y>
struct tmin<x,y,typename enable_if_c<(y < x)>::type> { const static long value = y; };
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
/*!A is_function /*!A is_function
......
...@@ -2948,7 +2948,7 @@ convergence: ...@@ -2948,7 +2948,7 @@ convergence:
template <typename EXP> template <typename EXP>
struct op : has_destructive_aliasing struct op : has_destructive_aliasing
{ {
const static long NR = EXP::NC; const static long NR = (EXP::NC&&EXP::NR)? (tmin<EXP::NR,EXP::NC>::value) : (0);
const static long NC = 1; const static long NC = 1;
typedef typename EXP::type type; typedef typename EXP::type type;
typedef typename EXP::mem_manager_type mem_manager_type; typedef typename EXP::mem_manager_type mem_manager_type;
...@@ -2957,7 +2957,7 @@ convergence: ...@@ -2957,7 +2957,7 @@ convergence:
{ return m(r,r); } { return m(r,r); }
template <typename M> template <typename M>
static long nr (const M& m) { return m.nr(); } static long nr (const M& m) { return std::min(m.nc(),m.nr()); }
template <typename M> template <typename M>
static long nc (const M& m) { return 1; } static long nc (const M& m) { return 1; }
}; };
...@@ -2970,14 +2970,6 @@ convergence: ...@@ -2970,14 +2970,6 @@ convergence:
const matrix_exp<EXP>& m const matrix_exp<EXP>& m
) )
{ {
// You can only get the diagonal for square matrices.
COMPILE_TIME_ASSERT(EXP::NR == EXP::NC);
DLIB_ASSERT(m.nr() == m.nc(),
"\tconst matrix_exp diag(const matrix_exp& m)"
<< "\n\tYou can only apply diag() to a square matrix"
<< "\n\tm.nr(): " << m.nr()
<< "\n\tm.nc(): " << m.nc()
);
typedef matrix_unary_exp<matrix_exp<EXP>,op_diag> exp; typedef matrix_unary_exp<matrix_exp<EXP>,op_diag> exp;
return matrix_exp<exp>(exp(m)); return matrix_exp<exp>(exp(m));
} }
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "../matrix.h" #include "../matrix.h"
#include "../rand.h" #include "../rand.h"
#include "../enable_if.h" #include "../enable_if.h"
#include "../algs.h"
#include "quantum_computing_abstract.h" #include "quantum_computing_abstract.h"
namespace dlib namespace dlib
...@@ -34,22 +35,6 @@ namespace dlib ...@@ -34,22 +35,6 @@ namespace dlib
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
// This is a template to compute the absolute value a number at compile time
template <long x, typename enabled=void>
struct abs { const static long value = x; };
template <long x>
struct abs<x,typename enable_if_c<(x < 0)>::type> { const static long value = -x; };
// ------------------------------------------------------------------------------------
// This is a template to compute the max of two values at compile time
template <long x, long y, typename enabled=void>
struct max { const static long value = x; };
template <long x, long y>
struct max<x,y,typename enable_if_c<(y > x)>::type> { const static long value = y; };
// ------------------------------------------------------------------------------------
} }
typedef std::complex<double> qc_scalar_type; typedef std::complex<double> qc_scalar_type;
...@@ -657,7 +642,7 @@ namespace dlib ...@@ -657,7 +642,7 @@ namespace dlib
target_mask <<= 1; target_mask <<= 1;
} }
static const long num_bits = qc_helpers::abs<control_bit-target_bit>::value+1; static const long num_bits = tabs<control_bit-target_bit>::value+1;
static const long dims = qc_helpers::exp_2_n<num_bits>::value; static const long dims = qc_helpers::exp_2_n<num_bits>::value;
const qc_scalar_type operator() (long r, long c) const const qc_scalar_type operator() (long r, long c) const
...@@ -742,8 +727,8 @@ namespace dlib ...@@ -742,8 +727,8 @@ namespace dlib
target_mask <<= 1; target_mask <<= 1;
} }
static const long num_bits = qc_helpers::max<qc_helpers::abs<control_bit1-target_bit>::value, static const long num_bits = tmax<tabs<control_bit1-target_bit>::value,
qc_helpers::abs<control_bit2-target_bit>::value>::value+1; tabs<control_bit2-target_bit>::value>::value+1;
static const long dims = qc_helpers::exp_2_n<num_bits>::value; static const long dims = qc_helpers::exp_2_n<num_bits>::value;
const qc_scalar_type operator() (long r, long c) const const qc_scalar_type operator() (long r, long c) const
......
...@@ -120,6 +120,10 @@ namespace ...@@ -120,6 +120,10 @@ namespace
mrc.set_size(3,4); mrc.set_size(3,4);
set_all_elements(mrc,1); set_all_elements(mrc,1);
DLIB_CASSERT(diag(mrc) == uniform_matrix<double>(3,1,1),"");
DLIB_CASSERT(diag(matrix<double>(mrc)) == uniform_matrix<double>(3,1,1),"");
matrix<double,2,3> mrc2; matrix<double,2,3> mrc2;
set_all_elements(mrc2,1); set_all_elements(mrc2,1);
DLIB_CASSERT((removerc<1,1>(mrc) == mrc2),""); DLIB_CASSERT((removerc<1,1>(mrc) == mrc2),"");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment