Commit 1578e277 authored by Davis King's avatar Davis King

Fixed fft_inplace() not compiling for compile time sized matrices.

parent ef811cbd
...@@ -472,7 +472,7 @@ namespace dlib ...@@ -472,7 +472,7 @@ namespace dlib
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template < typename T, long NR, long NC, typename MM, typename L > template < typename T, long NR, long NC, typename MM, typename L >
void fft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data) typename enable_if_c<NR==1||NC==1>::type fft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
// Note that we don't divide the outputs by data.size() so this isn't quite the inverse. // Note that we don't divide the outputs by data.size() so this isn't quite the inverse.
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
...@@ -485,19 +485,31 @@ namespace dlib ...@@ -485,19 +485,31 @@ namespace dlib
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc()) << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
); );
if (data.nr() == 1 || data.nc() == 1)
{
impl::twiddles<T> cs; impl::twiddles<T> cs;
impl::fft1d_inplace(data, false, cs); impl::fft1d_inplace(data, false, cs);
} }
else
template < typename T, long NR, long NC, typename MM, typename L >
typename disable_if_c<NR==1||NC==1>::type fft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
// Note that we don't divide the outputs by data.size() so this isn't quite the inverse.
{ {
// make sure requires clause is not broken
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
"\t void fft_inplace(data)"
<< "\n\t The number of rows and columns must be powers of two."
<< "\n\t data.nr(): "<< data.nr()
<< "\n\t data.nc(): "<< data.nc()
<< "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
);
impl::fft2d_inplace(data, false); impl::fft2d_inplace(data, false);
} }
}
// ----------------------------------------------------------------------------------------
template < typename T, long NR, long NC, typename MM, typename L > template < typename T, long NR, long NC, typename MM, typename L >
void ifft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data) typename enable_if_c<NR==1||NC==1>::type ifft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
{ {
// make sure requires clause is not broken // make sure requires clause is not broken
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()), DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
...@@ -509,16 +521,25 @@ namespace dlib ...@@ -509,16 +521,25 @@ namespace dlib
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc()) << "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
); );
if (data.nr() == 1 || data.nc() == 1)
{
impl::twiddles<T> cs; impl::twiddles<T> cs;
impl::fft1d_inplace(data, true, cs); impl::fft1d_inplace(data, true, cs);
} }
else
template < typename T, long NR, long NC, typename MM, typename L >
typename disable_if_c<NR==1||NC==1>::type ifft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
{ {
// make sure requires clause is not broken
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
"\t void ifft_inplace(data)"
<< "\n\t The number of rows and columns must be powers of two."
<< "\n\t data.nr(): "<< data.nr()
<< "\n\t data.nc(): "<< data.nc()
<< "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
);
impl::fft2d_inplace(data, true); impl::fft2d_inplace(data, true);
} }
}
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
...@@ -98,6 +98,28 @@ namespace ...@@ -98,6 +98,28 @@ namespace
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
template <long nr, long nc>
void test_real_compile_time_sized_ffts()
{
print_spinner();
const matrix<complex<double>,nr,nc> m1 = complex_matrix(real(rand_complex(nr,nc)));
const matrix<complex<float>,nr,nc> fm1 = matrix_cast<complex<float> >(complex_matrix(real(rand_complex(nr,nc))));
DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(m1))))-m1)) < 1e-16);
DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(fm1))))-fm1)) < 1e-7);
matrix<complex<double>,nr,nc> temp = m1;
matrix<complex<float>,nr,nc> ftemp = fm1;
fft_inplace(temp);
fft_inplace(ftemp);
DLIB_TEST(max(norm(temp-fft(m1))) < 1e-16);
DLIB_TEST(max(norm(ftemp-fft(fm1))) < 1e-7);
ifft_inplace(temp);
ifft_inplace(ftemp);
DLIB_TEST(max(norm(temp/temp.size()-m1)) < 1e-16);
DLIB_TEST(max(norm(ftemp/ftemp.size()-fm1)) < 1e-7);
}
void test_random_real_ffts() void test_random_real_ffts()
{ {
for (int iter = 0; iter < 10; ++iter) for (int iter = 0; iter < 10; ++iter)
...@@ -126,6 +148,10 @@ namespace ...@@ -126,6 +148,10 @@ namespace
} }
} }
} }
test_real_compile_time_sized_ffts<16,16>();
test_real_compile_time_sized_ffts<16,1>();
test_real_compile_time_sized_ffts<1,16>();
} }
// ---------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment