Commit c4942358 authored by Davis King's avatar Davis King

Made the interface to fft() and ifft() a little more flexible.

parent 6be28865
...@@ -34,10 +34,10 @@ namespace dlib ...@@ -34,10 +34,10 @@ namespace dlib
return temp; return temp;
} }
template <typename T, long NR, long NC, typename MM, typename L> template <typename EXP>
void permute ( void permute (
const matrix<std::complex<T>,NR,NC,MM,L>& data, const matrix_exp<EXP>& data,
matrix<std::complex<T>,NR,NC,MM,L>& outdata typename EXP::matrix_type& outdata
) )
{ {
outdata.set_size(data.size()); outdata.set_size(data.size());
...@@ -67,14 +67,16 @@ namespace dlib ...@@ -67,14 +67,16 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename T, long NR, long NC, typename MM, typename L> template <typename EXP>
matrix<std::complex<T>,NR,NC,MM,L> fft ( typename EXP::matrix_type fft (
const matrix<std::complex<T>,NR,NC,MM,L>& data const matrix_exp<EXP>& data
) )
{ {
if (data.size() == 0) if (data.size() == 0)
return data; return data;
// You have to give a complex matrix
COMPILE_TIME_ASSERT(is_complex<typename EXP::type>::value);
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_CASSERT(is_vector(data) && is_power_of_two(data.size()), DLIB_CASSERT(is_vector(data) && is_power_of_two(data.size()),
"\t void ifft(data)" "\t void ifft(data)"
...@@ -83,12 +85,14 @@ namespace dlib ...@@ -83,12 +85,14 @@ namespace dlib
<< "\n\t data.size(): " << data.size() << "\n\t data.size(): " << data.size()
); );
matrix<std::complex<T>,NR,NC,MM,L> outdata(data); typedef typename EXP::type::value_type T;
typename EXP::matrix_type outdata(data);
const long half = outdata.size()/2; const long half = outdata.size()/2;
typedef std::complex<T> ct; typedef std::complex<T> ct;
matrix<ct,0,1,MM,L> twiddle_factors(half); matrix<ct,0,1,typename EXP::mem_manager_type> twiddle_factors(half);
// compute the complex root of unity w // compute the complex root of unity w
const T temp = -2.0*pi/outdata.size(); const T temp = -2.0*pi/outdata.size();
...@@ -126,21 +130,23 @@ namespace dlib ...@@ -126,21 +130,23 @@ namespace dlib
skip *= 2; skip *= 2;
} }
matrix<std::complex<T>,NR,NC,MM,L> outperm; typename EXP::matrix_type outperm;
impl::permute(outdata, outperm); impl::permute(outdata, outperm);
return outperm; return outperm;
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <typename T, long NR, long NC, typename MM, typename L> template <typename EXP>
matrix<std::complex<T>,NR,NC,MM,L> ifft ( typename EXP::matrix_type ifft (
const matrix<std::complex<T>,NR,NC,MM,L>& data const matrix_exp<EXP>& data
) )
{ {
if (data.size() == 0) if (data.size() == 0)
return data; return data;
// You have to give a complex matrix
COMPILE_TIME_ASSERT(is_complex<typename EXP::type>::value);
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_CASSERT(is_vector(data) && is_power_of_two(data.size()), DLIB_CASSERT(is_vector(data) && is_power_of_two(data.size()),
"\t void ifft(data)" "\t void ifft(data)"
...@@ -150,13 +156,15 @@ namespace dlib ...@@ -150,13 +156,15 @@ namespace dlib
); );
matrix<std::complex<T>,NR,NC,MM,L> outdata; typedef typename EXP::type::value_type T;
typename EXP::matrix_type outdata;
impl::permute(data,outdata); impl::permute(data,outdata);
const long half = outdata.size()/2; const long half = outdata.size()/2;
typedef std::complex<T> ct; typedef std::complex<T> ct;
matrix<ct,0,1,MM,L> twiddle_factors(half); matrix<ct,0,1,typename EXP::mem_manager_type> twiddle_factors(half);
// compute the complex root of unity w // compute the complex root of unity w
const T temp = 2.0*pi/outdata.size(); const T temp = 2.0*pi/outdata.size();
...@@ -202,8 +210,9 @@ namespace dlib ...@@ -202,8 +210,9 @@ namespace dlib
#ifdef DLIB_USE_FFTW #ifdef DLIB_USE_FFTW
inline matrix<std::complex<double>,0,1> fft( template <long NR, long NC, typename MM, typename L>
const matrix<std::complex<double>,0,1>& data matrix<std::complex<double>,NR,NC,MM,L> call_fftw_fft(
const matrix<std::complex<double>,NR,NC,MM,L>& data
) )
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -214,7 +223,7 @@ namespace dlib ...@@ -214,7 +223,7 @@ namespace dlib
<< "\n\t data.size(): " << data.size() << "\n\t data.size(): " << data.size()
); );
matrix<std::complex<double>,0,1> m2(data.size()); matrix<std::complex<double>,NR,NC,MM,L> m2(data.size());
fftw_complex *in, *out; fftw_complex *in, *out;
fftw_plan p; fftw_plan p;
in = (fftw_complex*)&data(0); in = (fftw_complex*)&data(0);
...@@ -225,8 +234,9 @@ namespace dlib ...@@ -225,8 +234,9 @@ namespace dlib
return m2; return m2;
} }
inline matrix<std::complex<double>,0,1> ifft( template <long NR, long NC, typename MM, typename L>
const matrix<std::complex<double>,0,1>& data matrix<std::complex<double>,NR,NC,MM,L> call_fftw_ifft(
const matrix<std::complex<double>,NR,NC,MM,L>& data
) )
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -237,7 +247,7 @@ namespace dlib ...@@ -237,7 +247,7 @@ namespace dlib
<< "\n\t data.size(): " << data.size() << "\n\t data.size(): " << data.size()
); );
matrix<std::complex<double>,0,1> m2(data.size()); matrix<std::complex<double>,NR,NC,MM,L> m2(data.size());
fftw_complex *in, *out; fftw_complex *in, *out;
fftw_plan p; fftw_plan p;
in = (fftw_complex*)&data(0); in = (fftw_complex*)&data(0);
...@@ -248,6 +258,16 @@ namespace dlib ...@@ -248,6 +258,16 @@ namespace dlib
return m2/data.size(); return m2/data.size();
} }
// ----------------------------------------------------------------------------------------
// call FFTW for these cases:
inline matrix<std::complex<double>,0,1> fft (const matrix<std::complex<double>,0,1>& data) {return call_fftw_fft(data);}
inline matrix<std::complex<double>,0,1> ifft(const matrix<std::complex<double>,0,1>& data) {return call_fftw_ifft(data);}
inline matrix<std::complex<double>,1,0> fft (const matrix<std::complex<double>,1,0>& data) {return call_fftw_fft(data);}
inline matrix<std::complex<double>,1,0> ifft(const matrix<std::complex<double>,1,0>& data) {return call_fftw_ifft(data);}
inline matrix<std::complex<double> > fft (const matrix<std::complex<double> >& data) {return call_fftw_fft(data);}
inline matrix<std::complex<double> > ifft(const matrix<std::complex<double> >& data) {return call_fftw_ifft(data);}
#endif // DLIB_USE_FFTW #endif // DLIB_USE_FFTW
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -22,18 +22,13 @@ namespace dlib ...@@ -22,18 +22,13 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <typename EXP>
typename T, typename EXP::matrix_type fft (
long NR, const matrix_exp<EXP>& data
long NC,
typename MM,
typename L
>
matrix<std::complex<T>,NR,NC,MM,L> fft (
const matrix<std::complex<T>,NR,NC,MM,L>& data
); );
/*! /*!
requires requires
- data contains elements of type std::complex<>
- is_vector(data) == true - is_vector(data) == true
- is_power_of_two(data.size()) == true - is_power_of_two(data.size()) == true
ensures ensures
...@@ -53,18 +48,13 @@ namespace dlib ...@@ -53,18 +48,13 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < template <typename EXP>
typename T, typename EXP::matrix_type ifft (
long NR, const matrix_exp<EXP>& data
long NC,
typename MM,
typename L
>
matrix<std::complex<T>,NR,NC,MM,L> ifft (
const matrix<std::complex<T>,NR,NC,MM,L>& data
); );
/*! /*!
requires requires
- data contains elements of type std::complex<>
- is_vector(data) == true - is_vector(data) == true
- is_power_of_two(data.size()) == true - is_power_of_two(data.size()) == true
ensures ensures
......
...@@ -77,6 +77,23 @@ namespace ...@@ -77,6 +77,23 @@ namespace
} }
} }
// ----------------------------------------------------------------------------------------
void test_random_real_ffts()
{
print_spinner();
for (int iter = 0; iter < 10; ++iter)
{
for (int size = 1; size <= 64; size *= 2)
{
const matrix<complex<double>,0,1> m1 = complex_matrix(real(rand_complex(size)));
const matrix<complex<float>,0,1> fm1 = matrix_cast<complex<float> >(complex_matrix(real(rand_complex(size))));
DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(m1))))-m1)) < 1e-16);
}
}
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
class test_fft : public tester class test_fft : public tester
...@@ -93,6 +110,7 @@ namespace ...@@ -93,6 +110,7 @@ namespace
{ {
test_against_saved_good_ffts(); test_against_saved_good_ffts();
test_random_ffts(); test_random_ffts();
test_random_real_ffts();
} }
} a; } a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment