Commit 1578e277 authored by Davis King's avatar Davis King

Fixed fft_inplace() not compiling for compile time sized matrices.

parent ef811cbd
......@@ -472,7 +472,7 @@ namespace dlib
// ----------------------------------------------------------------------------------------
template < typename T, long NR, long NC, typename MM, typename L >
void fft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
typename enable_if_c<NR==1||NC==1>::type fft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
// Note that we don't divide the outputs by data.size() so this isn't quite the inverse.
{
// make sure requires clause is not broken
......@@ -485,19 +485,31 @@ namespace dlib
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
);
if (data.nr() == 1 || data.nc() == 1)
{
impl::twiddles<T> cs;
impl::fft1d_inplace(data, false, cs);
}
else
{
impl::fft2d_inplace(data, false);
}
impl::twiddles<T> cs;
impl::fft1d_inplace(data, false, cs);
}
template < typename T, long NR, long NC, typename MM, typename L >
void ifft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
typename disable_if_c<NR==1||NC==1>::type fft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
// Note that we don't divide the outputs by data.size() so this isn't quite the inverse.
{
// make sure requires clause is not broken
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
"\t void fft_inplace(data)"
<< "\n\t The number of rows and columns must be powers of two."
<< "\n\t data.nr(): "<< data.nr()
<< "\n\t data.nc(): "<< data.nc()
<< "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
);
impl::fft2d_inplace(data, false);
}
// ----------------------------------------------------------------------------------------
template < typename T, long NR, long NC, typename MM, typename L >
typename enable_if_c<NR==1||NC==1>::type ifft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
{
// make sure requires clause is not broken
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
......@@ -509,15 +521,24 @@ namespace dlib
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
);
if (data.nr() == 1 || data.nc() == 1)
{
impl::twiddles<T> cs;
impl::fft1d_inplace(data, true, cs);
}
else
{
impl::fft2d_inplace(data, true);
}
impl::twiddles<T> cs;
impl::fft1d_inplace(data, true, cs);
}
template < typename T, long NR, long NC, typename MM, typename L >
typename disable_if_c<NR==1||NC==1>::type ifft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
{
// make sure requires clause is not broken
DLIB_CASSERT(is_power_of_two(data.nr()) && is_power_of_two(data.nc()),
"\t void ifft_inplace(data)"
<< "\n\t The number of rows and columns must be powers of two."
<< "\n\t data.nr(): "<< data.nr()
<< "\n\t data.nc(): "<< data.nc()
<< "\n\t is_power_of_two(data.nr()): " << is_power_of_two(data.nr())
<< "\n\t is_power_of_two(data.nc()): " << is_power_of_two(data.nc())
);
impl::fft2d_inplace(data, true);
}
// ----------------------------------------------------------------------------------------
......
......@@ -98,6 +98,28 @@ namespace
// ----------------------------------------------------------------------------------------
template <long nr, long nc>
void test_real_compile_time_sized_ffts()
{
print_spinner();
const matrix<complex<double>,nr,nc> m1 = complex_matrix(real(rand_complex(nr,nc)));
const matrix<complex<float>,nr,nc> fm1 = matrix_cast<complex<float> >(complex_matrix(real(rand_complex(nr,nc))));
DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(m1))))-m1)) < 1e-16);
DLIB_TEST(max(norm(ifft(fft(complex_matrix(real(fm1))))-fm1)) < 1e-7);
matrix<complex<double>,nr,nc> temp = m1;
matrix<complex<float>,nr,nc> ftemp = fm1;
fft_inplace(temp);
fft_inplace(ftemp);
DLIB_TEST(max(norm(temp-fft(m1))) < 1e-16);
DLIB_TEST(max(norm(ftemp-fft(fm1))) < 1e-7);
ifft_inplace(temp);
ifft_inplace(ftemp);
DLIB_TEST(max(norm(temp/temp.size()-m1)) < 1e-16);
DLIB_TEST(max(norm(ftemp/ftemp.size()-fm1)) < 1e-7);
}
void test_random_real_ffts()
{
for (int iter = 0; iter < 10; ++iter)
......@@ -126,6 +148,10 @@ namespace
}
}
}
test_real_compile_time_sized_ffts<16,16>();
test_real_compile_time_sized_ffts<16,1>();
test_real_compile_time_sized_ffts<1,16>();
}
// ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment