@skyword
2016-10-05T01:01:00.000000Z
字数 10475
阅读 2372
数学
数论
看了两天FFT
一直没有学fft,直到10天前打beijing online contest时,碰到了一个NTT的裸题,发现必须要看看了
hexo博客配置有一些问题,mathjax公式一直不能正常加载,可以访问我在cmdmarkdown上发布的版本 here
这几天看了一些博客,发现它纯粹在数学上需要功底。FFT,NTT的题目中,简单的那部分基本可以分为两个方面:建模列出表达式+套版
理解模板中那些纯粹属于FFT,NTT的做法并不难。难的是怎么用到题目中,这需要数学功底。。
我就不在这里从头写了,网上已经有了很好的博客来讲解FFT
首先,我决定尝试叙述一下逻辑关系:
我们现在要解决的,是高效的计算多项式乘法,传统朴素的方法是 的。
而FFT,即快速傅里叶变换,通过某种技巧,使复杂度降到 , 但是FFT的弱点在于,它的计算在复数域内进行,因此存在精度问题。
于是有了NTT,即快速数论变换。这一变换是对FFT的改进,使得所有计算在,即模p剩余系下考虑。因此,整数意义下的FFT问题可以用NTT来解决,从而避免精度误差。
FFT
贴出一些文章:
这两篇文章说的很清楚。傅里叶变换建立在一些基础的处理手法和引理上。
首先是多项式的点值表达,很好理解,一个n次一元多项式,与n个不同的点是相互确定的。这个可以从代数的角度上来理解。文章的推导也阐明了这点。
点值表达给我们带来的是,我们想要表述一个多项式,只需任取n个不同的点,就能唯一地表述它。我们想要求一个多项式,只要设法找到n个不同的该多项式经过的点,就可以求出其系数表达。
因此,我们接下来的工作就放在点值上来考虑。经验表明,只有真正想清楚了点值表达是什么以及它对下面的工作起到什么作用,才能更好地理解下面的推导。
教程里已经推导了,单纯的去取n个点,来作点值表达,是可以完成多项式相乘计算的。然而,普遍的时间复杂度还是 的。
到这时候,FFT才真正登场。
FFT的做法是,同样基于点值表达,但是选了一组很巧妙的点,即重单位复根
上面的文章证明了,具有同n重单位复根一样的那几条性质,因此可以代替n重单位复根来做计算。
所以要写NTT,只需要改动取 的部分即可。
我们需要的常数是大质数p,和它的原根g,p一般是取费马素数的形式,即:
目前想得到的一些自己对FFT,NTT的理解就是这样了,代码中有些细节的地方还需要我仔细琢磨,并没有搞得很透彻。
几道题目:
FFT入门题:
hdu1402
大数乘法,n位的大整数,实际上可以看成n次的多项式。两个大数的相乘,就可以理解成这两个多项式的乘法。在用FFT计算好乘法之外,我们额外做的就是进位和顺序调整了。
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <cmath>
#include <vector>
#include <utility>
#include <string>
#include <queue>
//#define maxn 1024
#define LL long long
#define fp freopen("in.txt","r",stdin)
#define fpo freopen("out9.txt","w",stdout)
//#define judge
using namespace std;
const double eps = 1e-10;
const double pi = acos(-1.0);
const int maxn = 250050;
const int INF = 0x3f3f3f3f;
const double lg2 = log(2.0);
const LL MOD = 1e9+7;
#define MAX 5050
struct Complex
{
double re,im;
Complex(double re = 0.0, double im = 0.0)
{
this->re = re;
this->im = im;
}
Complex operator -(const Complex &elem) const
{
return Complex(this->re - elem.re, this->im - elem.im);
}
Complex operator +(const Complex &elem) const
{
return Complex(this->re + elem.re, this->im + elem.im);
}
Complex operator *(const Complex &elem) const
{
return Complex(this->re * elem.re - this->im * elem.im, this->im*elem.re + this->re*elem.im);
}
void val(double re = 0.0, double im = 0.0)
{
this->re = re;
this->im = im;
}
};
Complex A[maxn],B[maxn];
int res[maxn],len,multilen,len1,len2;
char s1[maxn],s2[maxn];
void swap(Complex &a, Complex &b)
{
Complex tmp = a;
a = b, b=tmp;
}
void init()
{
len1 = strlen(s1),len2 = strlen(s2);
multilen = max(len1,len2);
len = 1;
while(len<(multilen<<1)) len<<=1;
for(int i = 0; i <len1; i++)
A[i].val(s1[len1-i-1]-'0',0);
for(int i = 0; i < len2; i++)
B[i].val(s2[len2-i-1]-'0',0);
for(int i = len1; i < len; i++)A[i].val();
for(int i = len2; i < len; i++)B[i].val();
}
void rader(Complex y[])
{
for(int i = 1,j=len>>1,k; i < len-1; i++)
{
if(i<j) swap(y[i],y[j]);
k = len>>1;
while(j >= k)
{
j-=k;
k>>=1;
}
if(j<k)j+=k;
}
}
//op==1 DFT
//op==-1 IDFT
void FFT(Complex y[], int op)
{
rader(y);
for(int h = 2; h <=len; h<<=1)
{
Complex wn(cos(op*2*pi/h),sin(op*2*pi/h));
for(int i = 0; i < len; i +=h)
{
Complex w(1,0);
for(int j = i; j < i+h/2; j++)
{
Complex u = y[j];
Complex t = w * y[j + h/2];
y[j] = u + t;
y[j+h/2] = u - t;
w = w *wn;
}
}
}
if(op==-1)// for IDFT
{
for(int i = 0; i < len ; i++)
{
y[i].re/=len;
}
}
}
void convolution(Complex *A, Complex *B)
{
FFT(A,1),FFT(B,1);
for(int i = 0; i <len; i++)
{
A[i] = A[i] * B[i];
}
FFT(A,-1);
for(int i = 0 ; i < len; i++)
res[i] = (int)(A[i].re +0.5);
}
void adjust(int *arr)
{
for(int i = 0; i < len ; i++)
{
res[i+1]+=res[i]/10;
res[i]%=10;
}
while(--len && res[len]==0);
}
void print(int *arr)
{
for(int i = len; i>=0;i--)
{
printf("%c",arr[i]+'0');
}
printf("\n");
}
int main()
{
while(gets(s1)&&gets(s2))
{
init();
convolution(A,B);
adjust(res);
print(res);
}
}
51nod上有一道同样的题目,好像规模更大,我用fft和ntt做了两次,有趣的是NTT跑的比FFT快了不少。
hdu4609
n条线段,有各自的长度,任取三条,求能组成三角形的组合数目。
这个题就不那么裸了,用到FFT的地方是,用num[]数组来存任意取两条线段能组成长度为i的方案数,基于两边之和大于第三边的原则,枚举三角形中的最长边来计数。然后去掉几类重复的情况,个人觉得这是更考察功底的地方。。kuangbin的题解
/*************************************************************************
> File Name: hdu4609.cpp
> Author: skyword
> Mail: skywordsun@gmail.com
> Created Time: 2016年10月02日 星期日 11时11分32秒
************************************************************************/
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <vector>
#include <set>
#include <algorithm>
#include <queue>
using namespace std;
const double pi = acos(-1.0);
#define maxn 400050
#define ll long long
#define fp freopen("in.txt","r",stdin);
struct Complex
{
double re,im;
Complex(double re=0.0, double im = 0.0)
{
this->re = re; this->im = im;
}
Complex operator +(const Complex &b)
{
return Complex(this->re + b.re, this->im + b.im);
}
Complex operator -(const Complex &b)
{
return Complex(this->re - b.re , this->im - b.im);
}
Complex operator *(const Complex &b)
{
return Complex(this->re*b.re - this->im*b.im , this->re*b.im + this->im*b.re);
}
};
void rader(Complex y[],int len)
{
int i,j,k;
for(i = 1,j=len/2; i<len-1;i++)
{
if(i<j) swap(y[i],y[j]);
k = len/2;
while(j >= k)
{
j-=k;
k/=2;
}
if(j<k) j += k;
}
}
void fft(Complex y[], int len , int op)
{
rader(y,len);
for(int h = 2; h <= len; h<<=1)
{
Complex wn(cos(-op*2*pi/h),sin(-op*2*pi/h));
for(int j = 0; j < len; j+=h)
{
Complex w(1,0);
for(int k = j; k < j+h/2; k++)
{
Complex u = y[k];
Complex t = w*y[k+h/2];
y[k] = u+t;
y[k+h/2] = u-t;//butterfly op
w = w* wn;
}
}
}
if(op == -1)
{
for(int i =0;i<len;i++)
{
y[i].re /= len;
}
}
}
Complex x[maxn];
int a[maxn];
ll num[maxn],sum[maxn];
int t,n;
int main()
{
scanf("%d",&t);
while(t--)
{
scanf("%d",&n);
memset(num,0,sizeof(num));
for(int i = 0; i<n;i++)
{
scanf("%d",&a[i]);
num[a[i]]++;
}
sort(a,a+n);
int len1 = a[n-1]+1;
int len = 1; // multilength
while(len < 2*len1) len <<= 1;
//init
for(int i = 0; i<len1 ; i++) x[i] = Complex(num[i],0);
for(int i = len1; i<len; i++) x[i] = Complex(0,0);
fft(x,len,1);
for(int i= 0; i<len;i++) x[i] = x[i] * x[i];
fft(x,len,-1);
for(int i = 0; i < len; i++)
{
num[i] = (ll)(x[i].re + 0.5);
}
len = 2*a[n-1];
for(int i =0; i<n;i++)
num[a[i]*2]--;
for(int i = 0; i <=len;i++) num[i]/=2;
sum[0] = 0;
for(int i = 1; i<=len;i++) sum[i] = sum[i-1] + num[i];
ll cnt = 0;
for(int i = 0; i<n;i++)
{
cnt += (sum[len]-sum[a[i]]);
cnt -= (ll)(n-1-i)*i;
cnt -= (n-1);
cnt -= (ll)(n-1-i)*(n-2-i)/2;
}
ll all = (ll)n*(n-1)*(n-2)/6;
double ans = (double)cnt/all;
printf("%.7lf\n",ans);
}
}
hdu5829
这是今年多校第八场的1009.
题意就叙述的很拗口。。我真的感觉有些出题人该好好润色一下英语表达了。。
题目给了一个n元素的数集,考虑任意的非空子集,考虑其前k大元素(如果子集本身元素数目小于k,就指的是其中全部元素)之和。对每一个给定的k,对所有子集,求sum的总和,即
懒得写了。。网上有题解。。
这个题用NTT做蛮好。更难的点是如何处理出卷积的形式。我自己琢磨了挺久的。。什么时候单独写个题解好了。。挺考验熟练度和功底
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <cmath>
#include <vector>
#include <utility>
#include <string>
#include <queue>
//#define maxn 1024
#define LL long long
#define fp freopen("in.txt","r",stdin)
#define fpo freopen("out9.txt","w",stdout)
//#define judge
using namespace std;
const double eps = 1e-10;
const double pi = acos(-1.0);
const int maxn = (1e5+20);
const int INF = 0x3f3f3f3f;
const int p = 998244353;
const int G = 3;
const double lg2 = log(2.0);
const LL MOD = 1e9+7;
#define MAX 5050
LL A[maxn<<2],B[maxn<<2];
LL quick_mod(LL a, LL b, LL m)
{
LL ans = 1;
while(b)
{
if(b&1) ans = ans * a % m;
a = a * a%m;
b >>= 1;
}
return ans;
}
void rader(LL y[], int len)
{
for(int i = 1,j=len/2; i < len-1; i++)
{
if(i < j) swap(y[i], y[j]);
int k = len/2;
while(j >= k)
{
j-=k;
k/=2;
}
if(j < k)j+=k;
}
}
void NTT(LL y[], int len , int op)
{
rader(y,len);
for(int h = 2; h <= len; h <<=1)
{
LL wn = quick_mod(G,(p-1)/h,p);
if(op == -1)
{
wn = quick_mod(wn,p-2,p);
}
// now wn is the rotation factor.
for(int j = 0; j <len; j+=h)
{
LL w = 1;
for(int k = j; k < j + h/2; k++)
{
LL u = y[k];
LL t = (w * y[k + h/2])%p;
y[k] = (u + t)%p;
y[k + h/2] = (u - t + p)%p;
w = w * wn % p;
}
}
}
// for IDFT(or maybe we call it IFNT)
if(op==-1)
{
LL inv = quick_mod(len , p-2, p);
for(int i = 0; i <len; i++)
y[i] = y[i] * inv % p;
}
}
int n,t,a[maxn],ans[maxn];
LL fac[maxn],tfac[maxn],inv_fac[maxn],inv_tfac[maxn];
void init()
{
fac[0] = tfac[0] = inv_fac[0] = inv_tfac[0] =1;
for(int i = 1; i < maxn ; i++)
{
fac[i] = fac[i-1] * i % p;
tfac[i] = 2 * tfac[i-1] %p;
inv_fac[i] = quick_mod(fac[i], p - 2, p);
inv_tfac[i] = quick_mod(tfac[i], p - 2, p);
}
}
int main()
{
init();
scanf("%d",&t);
while(t--)
{
scanf("%d",&n);
for(int i = 1; i <= n; i++) scanf("%d",&a[i]);
sort(a+1, a+1+n,greater<int>());
int len = 1;
while(len < ((n<<1)+1) ) len <<= 1;
for(int i = 0; i < len ; i++)
{
if(i <= n)
A[i] = tfac[n-i] * inv_fac[i] % p;
else A[i] = 0;
if(i <= n && i >= 1)
B[i] = a[i] * fac[i-1] % p;
else B[i] = 0;
}
reverse(B+1,B+1+n);
NTT(A,len,1);
NTT(B,len,1);
for(int i = 0; i<len;i++)
{
A[i] = A[i] * B[i] % p;
}
NTT(A, len ,-1);
for(int i = 1; i<=n; i++)
{
ans[i] =((inv_tfac[i] * inv_fac[i-1])%p) * A[n-i+1] %p;
}
for(int i = 1; i <=n;i++)
{
ans[i] = (ans[i] + ans[i-1]) % p;
}
for(int i = 1; i <=n;i++)
{
printf("%d ",ans[i]);
}
puts("");
}
}
hihocoder1388 : Periodic Signal
这题是今年北京网络赛的F题
经过处理可以知道,核心在于计算
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <cmath>
#include <vector>
#include <utility>
#include <string>
#include <queue>
//#define maxn 1024
#define LL long long
#define fp freopen("in.txt","r",stdin)
#define fpo freopen("out9.txt","w",stdout)
//#define judge
using namespace std;
const double eps = 1e-10;
const double pi = acos(-1.0);
const int maxn = 200020;
const int INF = 0x3f3f3f3f;
const LL p = 180143985094819841LL;
const int G = 3;
const double lg2 = log(2.0);
const LL MOD = 1e9+7;
#define MAX 5050
LL wn[20];
LL mul(LL x,LL y)
{
return (x*y-(LL)(x / (long double)p*y+1e-3)*p +p)%p;
}
LL quick_mod(LL a, LL b, LL m)
{
LL ans = 1;
while(b)
{
if(b&1) ans = mul(ans , a );
a = mul(a , a);
b >>= 1;
}
return ans;
}
void getwn()
{
for(int i = 1; i <=18;i++)
{
int t = 1<<i;
wn[i] = quick_mod(G,(p-1)/t,p);
}
}
void rader(LL y[], int len)
{
for(int i = 1,j=len/2; i < len-1; i++)
{
if(i < j) swap(y[i], y[j]);
int k = len/2;
while(j >= k)
{
j-=k;
k/=2;
}
if(j < k)j+=k;
}
}
void NTT(LL y[], int len , int op)
{
rader(y,len);
int id = 0;
for(int h = 2; h <= len; h <<=1)
{
id++;
// now wn is the rotation factor.
for(int j = 0; j <len; j+=h)
{
LL w = 1;
for(int k = j; k < j + h/2; k++)
{
LL u = y[k];
LL t = mul(y[k + h/2], w);
y[k] = (u + t)%p;
y[k + h/2] = (u - t + p)%p;
w = mul(w , wn[id]);
}
}
}
// for IDFT(or maybe we call it IFNT)
if(op==-1)
{
for(int i = 1; i < len/2; i++)
swap(y[i], y[len-i]);
LL inv = quick_mod(len , p-2, p);
for(int i = 0; i <len; i++)
y[i] = mul(y[i] , inv );
}
}
int t,n;
LL a[60060],b[60060];
LL A[maxn],B[maxn],C[maxn],ans;
LL sum = 0;
void init()
{
for(int i = 0; i < maxn;i++)
{
if(i<60060)
A[i]=B[i]=C[i]=a[i]=b[i]=0;
else A[i]=B[i]=C[i]=0;
}
}
int main()
{
getwn();
scanf("%d",&t);
while(t--)
{
init();
sum = 0;
scanf("%d",&n);
for(int i = 0; i < n; i++)
{
scanf("%lld",&a[i]);
sum += a[i]*a[i];
}
for(int i = 0; i < n; i++)
{
scanf("%lld",&b[i]);
sum += b[i]*b[i];
}
int len = 1;
while(len < (n<<1) ) len<<=1;
for(int i = 0; i < n ; i++)
{
A[i] = a[i];
}
for(int i = 0; i < n; i++)
{
B[i] = b[n-1-i];
}
NTT(A,len,1);
NTT(B,len,1);
for(int i = 0; i < len; i++)
{
C[i] = mul(A[i],B[i]);
}
NTT(C,len,-1);
ans = C[n-1];
for(int i = 0 ; i < n-2; i++)
{
ans = max(ans, C[i]+C[i+n]);
}
//cout<<"**"<<sum<<endl;
sum -= (2LL*ans);
printf("%lld\n",sum);
}
}
就先写这么多吧。。