Commit ccb1c95f authored by Davis King's avatar Davis King

Made dlib's built in fft faster by tweaking a few things and adding a twiddle

cache.
parent eab9604d
......@@ -95,27 +95,72 @@ namespace dlib
// ------------------------------------------------------------------------------------
template <typename T>
class twiddles
{
/*!
The point of this object is to cache the twiddle values so we don't
recompute them over and over inside R8TX().
!*/
public:
twiddles()
{
data.resize(64);
}
const std::complex<T>* get_twiddles (
int p
)
/*!
requires
- 0 <= p <= 64
ensures
- returns a pointer to the twiddle factors needed by R8TX if nxtlt == 2^p
!*/
{
// Compute the twiddle factors for this p value if we haven't done so
// already.
if (data[p].size() == 0)
{
const int nxtlt = 0x1 << p;
data[p].reserve(nxtlt*7);
const T twopi = 6.2831853071795865; /* 2.0 * pi */
const T scale = twopi/(nxtlt*8.0);
std::complex<T> cs[7];
for (int j = 0; j < nxtlt; ++j)
{
const T arg = j*scale;
cs[0] = std::complex<T>(std::cos(arg),std::sin(arg));
cs[1] = cs[0]*cs[0];
cs[2] = cs[1]*cs[0];
cs[3] = cs[1]*cs[1];
cs[4] = cs[2]*cs[1];
cs[5] = cs[2]*cs[2];
cs[6] = cs[3]*cs[2];
data[p].insert(data[p].end(), cs, cs+7);
}
}
return &data[p][0];
}
private:
std::vector<std::vector<std::complex<T> > > data;
};
// ----------------------------------------------------------------------------------------
/* Radix-8 iteration subroutine */
template <typename T>
void R8TX(int nxtlt, int nthpo, int length,
void R8TX(int nxtlt, int nthpo, int length, const std::complex<T>* cs,
std::complex<T> *cc0, std::complex<T> *cc1, std::complex<T> *cc2, std::complex<T> *cc3,
std::complex<T> *cc4, std::complex<T> *cc5, std::complex<T> *cc6, std::complex<T> *cc7)
{
const T irt2 = 0.707106781186548; /* 1.0/sqrt(2.0) */
const T twopi = 6.2831853071795865; /* 2.0 * pi */
const T scale = twopi/length;
for(int j=0; j<nxtlt; j++)
{
const T arg = j*scale;
const std::complex<T> cs1(std::cos(arg),std::sin(arg));
const std::complex<T> cs2 = cs1*cs1;
const std::complex<T> cs3 = cs2*cs1;
const std::complex<T> cs4 = cs2*cs2;
const std::complex<T> cs5 = cs3*cs2;
const std::complex<T> cs6 = cs3*cs3;
const std::complex<T> cs7 = cs4*cs3;
for(int k=j;k<nthpo;k+=length)
{
std::complex<T> a0, a1, a2, a3, a4, a5, a6, a7;
......@@ -144,18 +189,6 @@ namespace dlib
const std::complex<T> tmp2(-irt2*(b7.real()+b7.imag()), irt2*(b7.real()-b7.imag()));
cc0[k] = b0 + b1;
if(j>0)
{
cc1[k] = cs4*(b0-b1);
cc2[k] = cs2*(b2+tmp0);
cc3[k] = cs6*(b2-tmp0);
cc4[k] = cs1*(b4+tmp1);
cc5[k] = cs5*(b4-tmp1);
cc6[k] = cs3*(b6+tmp2);
cc7[k] = cs7*(b6-tmp2);
}
else
{
cc1[k] = b0 - b1;
cc2[k] = b2 + tmp0;
cc3[k] = b2 - tmp0;
......@@ -163,15 +196,26 @@ namespace dlib
cc5[k] = b4 - tmp1;
cc6[k] = b6 + tmp2;
cc7[k] = b6 - tmp2;
if(j>0)
{
cc1[k] *= cs[3];
cc2[k] *= cs[1];
cc3[k] *= cs[5];
cc4[k] *= cs[0];
cc5[k] *= cs[4];
cc6[k] *= cs[2];
cc7[k] *= cs[6];
}
}
cs += 7;
}
}
// ------------------------------------------------------------------------------------
template <typename T, long NR, long NC, typename MM, typename layout>
void fft1d_inplace(matrix<std::complex<T>,NR,NC,MM,layout>& data, bool do_backward_fft)
void fft1d_inplace(matrix<std::complex<T>,NR,NC,MM,layout>& data, bool do_backward_fft, twiddles<T>& cs)
/*!
requires
- is_vector(data) == true
......@@ -193,7 +237,7 @@ namespace dlib
std::complex<T>* const b = &data(0);
int L[16],L1,L2,L3,L4,L5,L6,L7,L8,L9,L10,L11,L12,L13,L14,L15;
int j1,j2,j3,j4,j5,j6,j7,j8,j9,j10,j11,j12,j13,j14;
int j, ij, ji, ij1, ji1;
int j, ij, ji;
int n2pow, n8pow, nthpo, ipass, nxtlt, length;
n2pow = fastlog2(data.size());
......@@ -206,9 +250,10 @@ namespace dlib
/* Radix 8 iterations */
for(ipass=1;ipass<=n8pow;ipass++)
{
nxtlt = 0x1 << (n2pow - 3*ipass);
const int p = n2pow - 3*ipass;
nxtlt = 0x1 << p;
length = 8*nxtlt;
R8TX(nxtlt, nthpo, length,
R8TX(nxtlt, nthpo, length, cs.get_twiddles(p),
b, b+nxtlt, b+2*nxtlt, b+3*nxtlt,
b+4*nxtlt, b+5*nxtlt, b+6*nxtlt, b+7*nxtlt);
}
......@@ -235,28 +280,26 @@ namespace dlib
L15=L[1];L14=L[2];L13=L[3];L12=L[4];L11=L[5];L10=L[6];L9=L[7];
L8=L[8];L7=L[9];L6=L[10];L5=L[11];L4=L[12];L3=L[13];L2=L[14];L1=L[15];
ij = 1;
for(j1=1;j1<=L1;j1++)
for(j2=j1;j2<=L2;j2+=L1)
for(j3=j2;j3<=L3;j3+=L2)
for(j4=j3;j4<=L4;j4+=L3)
for(j5=j4;j5<=L5;j5+=L4)
for(j6=j5;j6<=L6;j6+=L5)
for(j7=j6;j7<=L7;j7+=L6)
for(j8=j7;j8<=L8;j8+=L7)
for(j9=j8;j9<=L9;j9+=L8)
for(j10=j9;j10<=L10;j10+=L9)
for(j11=j10;j11<=L11;j11+=L10)
for(j12=j11;j12<=L12;j12+=L11)
for(j13=j12;j13<=L13;j13+=L12)
for(j14=j13;j14<=L14;j14+=L13)
for(ji=j14;ji<=L15;ji+=L14)
{
ij1 = ij-1;
ji1 = ji-1;
if(ij-ji<0)
swap(b[ij1], b[ji1]);
ij = 0;
for(j1=0;j1<L1;j1++)
for(j2=j1;j2<L2;j2+=L1)
for(j3=j2;j3<L3;j3+=L2)
for(j4=j3;j4<L4;j4+=L3)
for(j5=j4;j5<L5;j5+=L4)
for(j6=j5;j6<L6;j6+=L5)
for(j7=j6;j7<L7;j7+=L6)
for(j8=j7;j8<L8;j8+=L7)
for(j9=j8;j9<L9;j9+=L8)
for(j10=j9;j10<L10;j10+=L9)
for(j11=j10;j11<L11;j11+=L10)
for(j12=j11;j12<L12;j12+=L11)
for(j13=j12;j13<L13;j13+=L12)
for(j14=j13;j14<L14;j14+=L13)
for(ji=j14;ji<L15;ji+=L14)
{
if(ij<ji)
swap(b[ij], b[ji]);
ij++;
}
......@@ -283,12 +326,13 @@ namespace dlib
return;
matrix<std::complex<double> > buff;
twiddles<double> cs;
// Compute transform row by row
for(long r=0; r<data.nr(); ++r)
{
buff = matrix_cast<std::complex<double> >(rowm(data,r));
fft1d_inplace(buff, do_backward_fft);
fft1d_inplace(buff, do_backward_fft, cs);
set_rowm(data,r) = matrix_cast<std::complex<T> >(buff);
}
......@@ -296,7 +340,7 @@ namespace dlib
for(long c=0; c<data.nc(); ++c)
{
buff = matrix_cast<std::complex<double> >(colm(data,c));
fft1d_inplace(buff, do_backward_fft);
fft1d_inplace(buff, do_backward_fft, cs);
set_colm(data,c) = matrix_cast<std::complex<T> >(buff);
}
}
......@@ -328,12 +372,13 @@ namespace dlib
matrix<std::complex<double> > buff;
data_out.set_size(data.nr(), data.nc());
twiddles<double> cs;
// Compute transform row by row
for(long r=0; r<data.nr(); ++r)
{
buff = matrix_cast<std::complex<double> >(rowm(data,r));
fft1d_inplace(buff, do_backward_fft);
fft1d_inplace(buff, do_backward_fft, cs);
set_rowm(data_out,r) = matrix_cast<std::complex<T> >(buff);
}
......@@ -341,7 +386,7 @@ namespace dlib
for(long c=0; c<data_out.nc(); ++c)
{
buff = matrix_cast<std::complex<double> >(colm(data_out,c));
fft1d_inplace(buff, do_backward_fft);
fft1d_inplace(buff, do_backward_fft, cs);
set_colm(data_out,c) = matrix_cast<std::complex<T> >(buff);
}
}
......@@ -370,7 +415,8 @@ namespace dlib
if (data.nr() == 1 || data.nc() == 1)
{
matrix<typename EXP::type> temp(data);
impl::fft1d_inplace(temp, false);
impl::twiddles<typename EXP::type::value_type> cs;
impl::fft1d_inplace(temp, false, cs);
return temp;
}
else
......@@ -403,7 +449,8 @@ namespace dlib
if (data.nr() == 1 || data.nc() == 1)
{
temp = data;
impl::fft1d_inplace(temp, true);
impl::twiddles<typename EXP::type::value_type> cs;
impl::fft1d_inplace(temp, true, cs);
}
else
{
......@@ -430,10 +477,15 @@ namespace dlib
);
if (data.nr() == 1 || data.nc() == 1)
impl::fft1d_inplace(data, false);
{
impl::twiddles<T> cs;
impl::fft1d_inplace(data, false, cs);
}
else
{
impl::fft2d_inplace(data, false);
}
}
template < typename T, long NR, long NC, typename MM, typename L >
void ifft_inplace (matrix<std::complex<T>,NR,NC,MM,L>& data)
......@@ -449,10 +501,15 @@ namespace dlib
);
if (data.nr() == 1 || data.nc() == 1)
impl::fft1d_inplace(data, true);
{
impl::twiddles<T> cs;
impl::fft1d_inplace(data, true, cs);
}
else
{
impl::fft2d_inplace(data, true);
}
}
// ----------------------------------------------------------------------------------------
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment