Commit c4942358 authored by Davis King's avatar Davis King

Made the interface to fft() and ifft() a little more flexible.

parent 6be28865
......@@ -34,10 +34,10 @@ namespace dlib
return temp;
}
template <typename T, long NR, long NC, typename MM, typename L>
template <typename EXP>
void permute (
const matrix<std::complex<T>,NR,NC,MM,L>& data,
matrix<std::complex<T>,NR,NC,MM,L>& outdata
const matrix_exp<EXP>& data,
typename EXP::matrix_type& outdata
)
{
outdata.set_size(data.size());
......@@ -67,14 +67,16 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <typename T, long NR, long NC, typename MM, typename L>
matrix<std::complex<T>,NR,NC,MM,L> fft (
const matrix<std::complex<T>,NR,NC,MM,L>& data
template <typename EXP>
typename EXP::matrix_type fft (
const matrix_exp<EXP>& data
)
{
if (data.size() == 0)
return data;
// You have to give a complex matrix
COMPILE_TIME_ASSERT(is_complex<typename EXP::type>::value);
// make sure requires clause is not broken
DLIB_CASSERT(is_vector(data) && is_power_of_two(data.size()),
"\t void ifft(data)"
......@@ -83,12 +85,14 @@ namespace dlib
<< "\n\t data.size(): " << data.size()
);
matrix<std::complex<T>,NR,NC,MM,L> outdata(data);
typedef typename EXP::type::value_type T;
typename EXP::matrix_type outdata(data);
const long half = outdata.size()/2;
typedef std::complex<T> ct;
matrix<ct,0,1,MM,L> twiddle_factors(half);
matrix<ct,0,1,typename EXP::mem_manager_type> twiddle_factors(half);
// compute the complex root of unity w
const T temp = -2.0*pi/outdata.size();
......@@ -126,21 +130,23 @@ namespace dlib
skip *= 2;
}
matrix<std::complex<T>,NR,NC,MM,L> outperm;
typename EXP::matrix_type outperm;
impl::permute(outdata, outperm);
return outperm;
}
// ----------------------------------------------------------------------------------------
template <typename T, long NR, long NC, typename MM, typename L>
matrix<std::complex<T>,NR,NC,MM,L> ifft (
const matrix<std::complex<T>,NR,NC,MM,L>& data
template <typename EXP>
typename EXP::matrix_type ifft (
const matrix_exp<EXP>& data
)
{
if (data.size() == 0)
return data;
// You have to give a complex matrix
COMPILE_TIME_ASSERT(is_complex<typename EXP::type>::value);
// make sure requires clause is not broken
DLIB_CASSERT(is_vector(data) && is_power_of_two(data.size()),
"\t void ifft(data)"
......@@ -150,13 +156,15 @@ namespace dlib
);
matrix<std::complex<T>,NR,NC,MM,L> outdata;
typedef typename EXP::type::value_type T;
typename EXP::matrix_type outdata;
impl::permute(data,outdata);
const long half = outdata.size()/2;
typedef std::complex<T> ct;
matrix<ct,0,1,MM,L> twiddle_factors(half);
matrix<ct,0,1,typename EXP::mem_manager_type> twiddle_factors(half);
// compute the complex root of unity w
const T temp = 2.0*pi/outdata.size();
......@@ -202,8 +210,9 @@ namespace dlib
#ifdef DLIB_USE_FFTW
inline matrix<std::complex<double>,0,1> fft(
const matrix<std::complex<double>,0,1>& data
template <long NR, long NC, typename MM, typename L>
matrix<std::complex<double>,NR,NC,MM,L> call_fftw_fft(
const matrix<std::complex<double>,NR,NC,MM,L>& data
)
{
// make sure requires clause is not broken
......@@ -214,7 +223,7 @@ namespace dlib
<< "\n\t data.size(): " << data.size()
);
matrix<std::complex<double>,0,1> m2(data.size());
matrix<std::complex<double>,NR,NC,MM,L> m2(data.size());
fftw_complex *in, *out;
fftw_plan p;
in = (fftw_complex*)&data(0);
......@@ -225,8 +234,9 @@ namespace dlib
return m2;
}
inline matrix<std::complex<double>,0,1> ifft(
const matrix<std::complex<double>,0,1>& data
template <long NR, long NC, typename MM, typename L>
matrix<std::complex<double>,NR,NC,MM,L> call_fftw_ifft(
const matrix<std::complex<double>,NR,NC,MM,L>& data
)
{
// make sure requires clause is not broken
......@@ -237,7 +247,7 @@ namespace dlib
<< "\n\t data.size(): " << data.size()
);
matrix<std::complex<double>,0,1> m2(data.size());
matrix<std::complex<double>,NR,NC,MM,L> m2(data.size());
fftw_complex *in, *out;
fftw_plan p;
in = (fftw_complex*)&data(0);
......@@ -248,6 +258,16 @@ namespace dlib
return m2/data.size();
}
// ----------------------------------------------------------------------------------------
// call FFTW for these cases:
inline matrix<std::complex<double>,0,1> fft (const matrix<std::complex<double>,0,1>& data) {return call_fftw_fft(data);}
inline matrix<std::complex<double>,0,1> ifft(const matrix<std::complex<double>,0,1>& data) {return call_fftw_ifft(data);}
inline matrix<std::complex<double>,1,0> fft (const matrix<std::complex<double>,1,0>& data) {return call_fftw_fft(data);}
inline matrix<std::complex<double>,1,0> ifft(const matrix<std::complex<double>,1,0>& data) {return call_fftw_ifft(data);}
inline matrix<std::complex<double> > fft (const matrix<std::complex<double> >& data) {return call_fftw_fft(data);}
inline matrix<std::complex<double> > ifft(const matrix<std::complex<double> >& data) {return call_fftw_ifft(data);}
#endif // DLIB_USE_FFTW
// ----------------------------------------------------------------------------------------
......
......@@ -22,18 +22,13 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename T,
long NR,
long NC,
typename MM,
typename L
>
matrix<std::complex<T>,NR,NC,MM,L> fft (
const matrix<std::complex<T>,NR,NC,MM,L>& data
template <typename EXP>
typename EXP::matrix_type fft (
const matrix_exp<EXP>& data
);
/*!
requires
- data contains elements of type std::complex<>
- is_vector(data) == true
- is_power_of_two(data.size()) == true
ensures
......@@ -53,18 +48,13 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template <
typename T,
long NR,
long NC,
typename MM,
typename L
>
matrix<std::complex<T>,NR,NC,MM,L> ifft (
const matrix<std::complex<T>,NR,NC,MM,L>& data
template <typename EXP>
typename EXP::matrix_type ifft (
const matrix_exp<EXP>& data
);
/*!
requires
- data contains elements of type std::complex<>
- is_vector(data) == true
- is_power_of_two(data.size()) == true
ensures
......
......@@ -77,6 +77,23 @@ namespace
}
}
// ----------------------------------------------------------------------------------------
void test_random_real_ffts()
{
print_spinner();
for (int iter = 0; iter < 10; ++iter)
{
for (int size = 1; size <= 64; size *= 2)
{
const matrix<complex<double>,0,1> m1 = complex_matrix(real(rand_complex(size)));
const matrix<complex<float>,0,1> fm1 = matrix_cast<complex<float> >(complex_matrix(real(rand_complex(size))));
DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(m1))))-m1)) < 1e-16);
}
}
}
// ----------------------------------------------------------------------------------------
class test_fft : public tester
......@@ -93,6 +110,7 @@ namespace
{
test_against_saved_good_ffts();
test_random_ffts();
test_random_real_ffts();
}
} a;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment