快速傅里叶变换 FFT 易懂版
数学
多项式乘法
大数乘法
FFT 算法主要用于多项式相乘或者是大数的乘法计算。FFT 看了好久才懂,原来一直看不懂的原因是一直没理解多项式的点值表达到底是怎么来的,晕晕乎乎地抱着算导看了两天,看了那个第30章好多遍都没能看明白,直到突然看懂了点值表达,接下来看那些都很顺了,一下子全看懂了,接下来我以比较不抽象的方法讲讲 FFT,推荐使用暗色的背景看,这样代码高亮比较好看。
多项式乘法:
先讲讲点值表达吧,多项式的点值表达,说白了也很简单,比如说有一个多项式
A(x)=∑i=0n−1aixi
,我们在这里顺便假设一个列向量
a=(a0,a1,a2,...,an−1)
, 我们将
A(x) 看作是一个函数,然后在数轴上找到这些点
x0,x1,...xn−1
,然后将这些点分别代入
A(x),然后就可以得到
n 个点
{(x0,y0),(x1,y1),...,(xn−1,yn−1)}
,然后这
n 个点就是这个
A(x) 的点值表达。点值表达有什么用呢,虽然点值表达一点都没扯到
A(x) 的任意一个系数
ak,但是认真观察一下下面的式子,你就能想明白了,当初就是看懂了这个才顿悟的,
⎡⎣⎢⎢⎢⎢⎢11⋮1x0x1⋮xn−1x20x21⋮x2n−1......⋱...xn−10xn−11⋮xn−1n−1⎤⎦⎥⎥⎥⎥⎥⎡⎣⎢⎢⎢a0a1⋮an−1⎤⎦⎥⎥⎥=⎡⎣⎢⎢⎢⎢y0y1⋮yn−1⎤⎦⎥⎥⎥⎥
观察这个式子的时候你是否注意到,如果你能求出前面那个矩阵的逆矩阵,然后与结果矩阵相乘,就能得到系数的那个矩阵了,但是现在问题来了,这个矩阵是否一定有逆矩阵呢。这点我们还是可以证明的,前面这个矩阵其实称为范德蒙德矩阵,它的行列式的值为 ∏0≤j<k≤n−1(xk−xj),明显当这 n 个点都互不相同的时候,行列式的值不为零,所以逆矩阵存在。
顺便讲一个名词,次数界,比如 A(x) 的次数界为 m,那么 A(x) 的最高次小于 m。
然后问题来了,这 n的点是可以任意选的吗,答案是肯定的,但是得保证一点,这 n 个点得是不同的点,不然逆矩阵会不存在。虽然说是可以任意选,但是不同的选法会有不同的时间复杂度。然后神奇的是,选了合适的点后,傅里叶大发神威,一下子把时间复杂度从 Θ(n2) 降到了 Θ(nlgn)。原来的话,对于任意一点 xk,可以用秦九韶在 Θ(n) 的时间把 yk 计算出来,然后有 n 个点,所以总的复杂度是 Θ(n2),傅里叶的方法接下来慢慢讲。
总体计算思路:
先大体讲一下计算多项式 A(x)∗B(x) 的过程,然后慢慢分析每个过程:
- 首先取好相应的 n 个点,然后分别求出 A(x) 对应的 n 个点 (x0,y′0),(x1,y′1),...,(xn−1,y′n−1) 和 B(x) 对应的 n 个点 (x0,y′′0),(x1,y′′1),...,(xn−1,y′′n−1)。
假设 C(x)=A(x)∗B(x),且 C(x) 对应的 2n 个点为 (x0,y0),(x1,y1),...,(x2n−1,y2n−1),那么容易看出对于 k=0,1,2,...,2n−1,有 yk=y′k∗y′′k,这样我们就可以求出 C(x) 的点值表达。
最后只需要通过 C(x) 的点值表达便可以将 C(x) 各个项的系数都算出来了,这样结果就出来了。
以上讲得这些,其实都只是多项式相乘的内容,这里先对上面三点进行解释。首先,就算不使用 FFT 还是可以通过上面的方法计算出 C(x),只不过光第三步求逆矩阵就有 Θ(n3) 的时间复杂度,所以就不考虑了,直接计算两个多项式相乘也才 Θ(n2) 的时间复杂度。接着讲讲为什么第二步为什么 C(x) 要取 2n 个点,两个次数界为 n 的多项式相乘得到的多项式的最高次肯定会小于等于 (n−1)+(n−1),也就是 2n−2,我们不妨将 C(x) 的次数界看成 2n,对结果没有影响,而且可以保证是偶数。现在问题来了,A(x) 和 B(x) 各自只有 n 个点,要怎么生出 2n 个点呢,其实答案很简单,直接多选一些点就行了,想想也不会觉得难理解,就像是一个 n 元一次的方程组,你有 n 个不同的方程就能解出这 n 个未知数,但是你如果有 2n 个方程组,你只需要任选 n 个方程就能解出来了,这里也是同样的道理。第三部我们称它为插值。
接下来我们讲讲怎么用 FFT 怎么计算的吧。
预备知识:
首先,我们刚刚提到了选择特定的点后,可以运用傅里叶将算法复杂度降到 Θ(nlgn),这里特定的点就是单位复数根。我们先补充一些知识点。
单位复数根:
n 次单位复数根是满足 ωn=1 的复数 ω,我们将 n 次的单位复数根记为 ωn,例如 ω2 表示它是 2 次单位复数根,且 (ω2)2=ω22=1。n 次单位复数根的解恰好有 n 个:对于 k=0,1,...,n−1,这些根是 e2πkn,为了更好地理解这个单位复数根,我们可以看下这个公式:
eiθ=cosθ+i∗sinθ
首先,我们得先证明这 n 个点是互不相同的点,这个还是挺容易的,我们先看上面那个公式,如果将 θ 加上 2π,那么得到的复数还是同一个复数。观察下我们要代入的根,它们的 0≤θ<2π,且这 n 个点的 θ 是均匀地分配在这个区间内的。在复平面中,eθ1∗eθ2 的意义相当于 eθ1 这个复数逆时针旋转了 θ2 的角度。这里大体可以这样理解,对于一个复数 z1=aeθ1 和复数 z2=beθ2 相乘,其中 a,b,θ1,θ2 均为常数,z1z2=aeθ1z2=beθ2=abeθ1+θ2,所以复数相乘的意义就是扩大 b 倍,且逆时针旋转 θ2。
所以这 n 个单位复数根可以看成是将复平面均匀地分成了 n 个部分,每部分角度为 360∘/n,所以这 n 个单位复数根各不相同。
由于刚才提到的 eθ1∗eθ2 的意义,我们这边可以用 ω0n,ω1n,...,ωn−1n 这 n 个点表示我们所需要的 n 个点。
接下来讲讲一些所需要的定理和引理。
消去引理:
对于任意的整数 n≥0,k≥0 以及 d≥0,有
ωdkdn=ωkn
证明如下,
ωdkdn=(e2πidn)dk=(e2πin)k=ωkn
推论, 对于
∀n 满足
n>0 且
2|n,有
ωn/2n=ω2=−1
折半引理:
对于 ∀n 满足 n>0 且 2|n,有
(ωk+n/2n)2=(ωkn)2
证明如下,
(ωk+n/2n)2=ω2k+nn=ω2knωnn=ω2kn=(ωkn)2
也可这样证明,因为
ωn/2n=−1,所以
ωk+n/2n=−ωk,所以结论成立。
求和引理:
对于 ∀n 满足 n>0, k≥0 且 n∤k,有
∑i=0n−1(ωkn)i=0
证明如下,
等比数列求和公式同样适用于复数,所以有
∑i=0n−1(ωkn)i=(ωkn)n−1ωkn−1=(ωnn)k−1ωkn−1=(1)k−1ωkn−1=0
因为
n∤k,所以可以保证了分母不为零。
DFT:
对于多项式
A(x)=∑i=0n−1ai∗xi
,我们取好
n 个
n 次单位复数根
ω0n,ω1n,...,ωn−1n,代入多项式
A(x) 后,假设系数向量为
a=(a0,a1,...,an−1),结果向量
y=(y0,y1,...,yn−1),其中对于
k=0,1,...,n−1,
yk=A(ωkn)=∑i=0n−1aiωkin
,结果向量
y 就是系数向量
a 的离散傅里叶变换
(DFT),也记为
y=DFTn(a)。
FFT:
接下来就是重头戏了,我们可以通过快速傅里叶变换 (FFT),利用单位复数根的特殊性质,可以在 Θ(nlgn) 的时间内计算出 DFTn(a),而如果代入任意值计算 DFTn(a) 所需时间为 Θ(n2),接下来我们提到的 n 均假设为 2 的整数幂,如果遇到 n 不是 2 的整数幂的情况也没什么影响,我们只需要把 n 拓展成 2 的整数幂,多出来的直接补零就行了,对最后的结果没有影响,但是处理起来容易非常多。
FFT 采用了分治的思想,我们按照 A(x) 的奇数下标和偶数下标的系数定义两个新的次数界为 n/2 的多项式 A[0](x) 和 A[1](x):
A[0](x)=a0+a2x+a4x2+...+an−2xn/2−1A[1](x)=a1+a3x+a5x2+...+an−1xn/2−1
认真观察上面两个多项式,你会发现:
A(x)=A[0](x2)+xA[1](x2)
所以,求
A(x) 在
ω0n,ω1n,...,ωn−1n 处值得问题转化为了求次数界为
n/2 的多项式
A[0](x) 和
A[1](x) 在点
(ω0n)2,(ω1n)2,...,(ωn−1n)2
处的值。现在问题来了,这边有
n 个点,要求的多项式的次数界仅仅只有
n/2,我们应该如果选取点呢?其实这个问题完全不需要纠结,因为这里其实只有
n2 个不同的点,为什么呢,注意到
n 是 2 的整数幂,又有
ωnn=1,所以这
n 个点可以看成
ω0%nn,ω2%nn,...,ω(n/2−1)∗2%nn,ωn/2∗2%nn,ω(n/2+1)∗2%nn,...,ω(n−1)∗2%nn
所以其实这里只有
n/2 个点, 所以对于
k=0,1,2,...n/2−1,有
ω2kn=e4πkn=e2πk(n/2)=ωkn/2
所以这个递归的条件没有问题了。
有了递归方法也得有退出条件,显然,当多项式里只剩一个系数的时候是退出的情况,且 y0=a0ω01=a0∗1=a0。
接下来具体分析一下怎么计算出 A(x) 的所有点值,运用 FFT 计算点值的时候,我们只需要计算 k=0,1,2,...,n/2−1 的情况就行了。这里我们把 ωkn 称为旋转因子,接下来开始介绍怎么计算:
yk=y[0]k+y[1]k=A[0](ω2kn)+ωknA[1](ω2kn)=A(ωkn)
yk+n/2=A(ωk+n/2n)=A[0](ω2k+nn)+ωk+n/2nA[1](ω2k+nn)=A[0](ω2kn)−ωknA[1](ω2kn)=y[0]k−ωkny[1]k
所以,假设
u=y[0]k,
t=y[1]k,且当前旋转因子
ω=ωkn,我们只需要算出
u 和
t,那么
yk=u+ωt 和
yk+n/2=u−ωt,我们把上面的操作叫做蝴蝶操作。
通过以上步骤,相应的多项式的点值便能以
Θ(nlgn) 的时间计算出来。
这样还是不够快,我们还可以将递归改成非递归从而让程序更快些,而且每次递归的时候要生出新的数列,操作颇为繁琐。
接下来我们以
n=8 为例,下面是程序执行的时候递归调用时的树,不知道这边树要怎么画,只能画这么挫的表格来替代了。
(a0,a1,a2,a3,a4,a5,a6,a7) |
(a0,a2,a4,a6) |
(a1,a3,a5,a7) |
(a0,a4) |
(a2,a6) |
(a1,a5) |
(a3,a7) |
(a0) |
(a4) |
(a2) |
(a6) |
(a1) |
(a5) |
(a3) |
(a7) |
在图中,叶子出现的顺序是一个位逆序置换,这个肯定看不懂,下面解释下,叶子原本的序号为,
0 , 1 , 2 , 3 , 4 , 5 , 6 , 7
转化为二进制数就是
000, 001, 010, 011, 100, 101, 110, 111
位逆序置换后就是
000, 100, 010, 110, 001, 101, 011, 111
也就是
0 , 4 , 2 , 6 , 1 , 5 , 3 , 7
当 n 为其它 2 的整数幂时有同样的结论,意思就是说,
我们可以先将 A(x) 的系数的位置调整一下就可以直接线性地处理这个数列了。
这里有一个算法可以直接调整好位置,叫雷德算法,这里就不多说了,直接上一段代码
void Rader(Complex y[], int n)
{
for(int i = 1, j = n >> 1, k; i < n - 1; ++ i)
{
if(i < j)
swap(y[i], y[j]);
k = n >> 1;
while(j >= k)
{
j -= k;
k >>= 1;
}
if(j < k)
j += k;
}
}
用了上面的位逆序置换的雷德算法后,就可以直接计算 DFTn(a)了。
DFT−1:
计算完 A(x) 和 B(x) 的所有点值后,对于 k=0,1,2,...,2n−1,直接计算 yk=y′k∗y′′k。这样 C(x) 的点值表达在 Θ(n) 的时间内计算出来了。
我们不妨把 C(x) 的系数向量设为 a=(a0,a1,a2,...,an−1),所以我们只需要计算出下面的范德蒙德矩阵的逆矩阵,然后与结果向量 y=(y0,y1,y2,...,yn−1) 相乘,就能求出系数向量 a,求系数向量的过程我们称之为插值,
⎡⎣⎢⎢⎢⎢11⋮11ωn⋮ωn−1n1ω2n⋮ω2(n−1)n......⋱...1ωn−1n⋮ω(n−1)(n−1)n⎤⎦⎥⎥⎥⎥⎡⎣⎢⎢⎢a0a1⋮an−1⎤⎦⎥⎥⎥=⎡⎣⎢⎢⎢⎢y0y1⋮yn−1⎤⎦⎥⎥⎥⎥
我们将上面那个范德蒙德矩阵记得 Vn,直接叫你去算 V−1n 肯定不好算的,但是这边有现成的结论,
对于 j,k=0,1,2,...,n−1,V−1n 在 (j,k) 处的元素为 ω−kjnn。
证明倒是挺容易的,
我们将 n×n 的单位矩阵记为 In,考虑 V−1nVn 中 (j,j′) 位置的值,
[V−1nVn]jj′=∑k=0n−1(ω−kjnn)(ωkj′n)=1n∑k=0n−1ωk(j−j′)n
当
j=j′ 时,该值明显为 1,当
j≠j′ 时,
−(n−1)≤j−j′≤n−1,所以
n∤(j−j′),所以根据求和引理,该值为零。综上所述,
VnV−1n=In。
所以我们可以根据这个定理推出 DFT−1n(y),也就是通过结果向量计算系数向量,
ai=1n∑k=0n−1ykω−kin
其中
i=0,1,2,...,n−1。
认真观察上面这个式子对比一下前面讲到
DFT 时的式子,对于
k=0,1,...,n−1,有
yk=∑i=0n−1aiωkjn
我们运用
FFT 计算
yk 和这里计算
ai 方法是类似的,我们要做的不过就是用
a 替代
y,计算
ωkn 改成计算
(ωkn)−1,最后在值全部计算完后每个值都除以
n 即可。
最后的最后,上一段代码,里面注释解释具体的实现过程和一些细节。
题目是 hdu 1402
代码:
#include<cmath>
#include<cstdio>
#include<cstring>
const int MAXN = 2e5 + 5;
const double PI = acos(-1.0);
#define max(a, b) (a) > (b) ? (a) : (b)
class Complex
{
public:
double real, imag;
Complex(double real = 0.0, double imag = 0.0)
{
this->real = real, this->imag = imag;
}
Complex operator - (const Complex &elem) const
{
return Complex(this->real - elem.real, this->imag - elem.imag);
}
Complex operator + (const Complex &elem) const
{
return Complex(this->real + elem.real, this->imag + elem.imag);
}
Complex operator * (const Complex &elem) const
{
return Complex(this->real * elem.real - this->imag * elem.imag, this->real * elem.imag + this->imag * elem.real);
}
void setValue(double real = 0.0, double imag = 0.0)
{
this->real = real, this->imag = imag;
}
};
Complex A[MAXN], B[MAXN];
int res[MAXN], len, mlen, len1, len2;
char str1[MAXN >> 1], str2[MAXN >> 1];
void Swap(Complex &a, Complex &b)
{
Complex tmp = a;
a = b;
b = tmp;
}
void Prepare()
{
len1 = strlen(str1), len2 = strlen(str2);
mlen = max(len1, len2);
len = 1;
// 将 len 扩大到 2 的整数幂
while(len < (mlen << 1))
len <<= 1;
//初始化多项式的系数
for(int i = 0; i < len1; ++ i)
A[i].setValue(str1[len1 - i - 1] - '0', 0);
for(int i = 0; i < len2; ++ i)
B[i].setValue(str2[len2 - i - 1] - '0', 0);
// 补 0
for(int i = len1; i < len; ++ i)
A[i].setValue();
for(int i = len2; i < len; ++ i)
B[i].setValue();
}
//雷德算法 位逆序置换
void Rader(Complex y[])
{
for(int i = 1, j = len >> 1, k; i < len - 1; ++ i)
{
if(i < j)
Swap(y[i], y[j]);
k = len >> 1;
while(j >= k)
{
j -= k;
k >>= 1;
}
if(j < k)
j += k;
}
}
//DFT : op == 1
//IDFT : op == -1
void FFT(Complex y[], int op)
{
//先位逆序置换
Rader(y);
// h 为每次要处理的长度, h = 1 时不需处理
for(int h = 2; h <= len; h <<= 1)
{
// Wn = e^(2 * PI / n),如果是插值,那么 Wn = e^(-2 * PI / n)
Complex Wn(cos(op * 2 * PI / h), sin(op * 2 * PI / h));
for(int i = 0; i < len; i += h)
{
//旋转因子,初始化为 e^0
Complex W(1, 0);
for(int j = i; j < i + h / 2; ++ j)
{
Complex u = y[j];
Complex t = W * y[j + h / 2];
//蝴蝶操作
y[j] = u + t;
y[j + h / 2] = u - t;
//每次更新旋转因子
W = W * Wn;
}
}
}
// 插值的时候要除以 len
if(op == -1)
for(int i = 0; i < len; ++ i)
y[i].real /= len;
}
//DFT 后将 A 和 B 相应点值相乘,将结果放到 res 里面
void Convolution(Complex *A, Complex *B)
{
//evaluation
FFT(A, 1), FFT(B, 1);
for(int i = 0; i < len; ++ i)
A[i] = A[i] * B[i];
//interpolation
FFT(A, -1);
for(int i = 0; i < len; ++ i)
res[i] = (int)(A[i].real + 0.5);
}
void Adjustment(int *arr)
{
//次数界为 len,所以不用担心进位不会进到第 len 位
for(int i = 0; i < len; ++ i)
{
res[i + 1] += res[i] / 10;
res[i] %= 10;
}
//去除多余的 0
while(-- len && res[len] == 0);
}
void Display(int *arr)
{
for(int i = len; i >= 0; -- i)
putchar(arr[i] + '0');
putchar('\n');
}
int main()
{
while(gets(str1) && gets(str2))
{
Prepare();
Convolution(A, B);
Adjustment(res);
Display(res);
}
return 0;
}