@xunuo
2017-01-20T09:40:35.000000Z
字数 3406
阅读 956
Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 65536/32768 K (Java/Others)
FFT
来源:HDU 1402 A * B Problem Plus
Calculate A * B.
Each line will contain two integers A and B. Process to end of file.
Note: the length of each integer will not exceed 50000.
For each case, output A * B in one line.
Sample Input:
1 2
1000 2
Sample Output:
2
2000
题意:
输入两个大整数 A、B(每个不超过 50000 位),输出它们的乘积。
解题思路:
如果用朴素的 O(n·m) 高精度乘法会超时......
应把乘法看成两个数位序列的卷积,用 FFT 或者 NTT 在 O(n log n) 时间内完成......
然而你啥都不知道~~呵呵哒
完整代码:
///这个是张晓雨学长的FFT模板
#include<stdio.h>
#include<math.h>
#include<complex>
#include<string.h>
#include<algorithm>
using namespace std;
// Capacity of the global input/coefficient buffers used by main().
#define maxn 400000
// pi, computed at start-up as acos(-1).
const double PI=acos(-1.0);
// Complex number type used for the FFT coefficients.
typedef complex <double> Complex;
/// Permutes y[0..len) into bit-reversed index order in place, as required
/// before the iterative butterfly passes in fft(). Assumes len is a power
/// of two (TODO confirm at call sites; main() only passes powers of two).
void rader(Complex *y, int len)
{
    // `rev` holds the bit-reversal of `idx`. Each step does a "reversed
    // increment": clear set bits from the top down, then set the first clear one.
    int rev = len / 2;
    for (int idx = 1; idx < len - 1; ++idx)
    {
        // Swap each pair only once (when idx is the smaller index).
        if (idx < rev)
            swap(y[idx], y[rev]);
        int bit = len / 2;
        for (; rev >= bit; bit /= 2)
            rev -= bit;
        // The scan stops only when rev < bit, so this bit is always clear.
        rev += bit;
    }
}
/// In-place iterative radix-2 FFT of y[0..len).
/// op ==  1: forward transform, twiddle exp(+2*pi*i/step).
/// op == -1: inverse transform; the result is also divided by len.
/// len is expected to be a power of two (see rader()).
void fft(Complex *y,int len,int op)
{
    rader(y,len);
    // Each pass merges adjacent blocks of size step/2 into blocks of size step.
    for(int step=2; step<=len; step<<=1)
    {
        const int half=step/2;
        double ang=op*2*PI/step;
        const Complex wn(cos(ang),sin(ang));
        for(int base=0; base<len; base+=step)
        {
            // Twiddle factor advanced by repeated multiplication by wn.
            Complex w(1,0);
            for(int k=0; k<half; ++k)
            {
                Complex &lo=y[base+k];
                Complex &hi=y[base+k+half];
                const Complex u=lo;
                const Complex t=w*hi;
                lo=u+t;
                hi=u-t;
                w=w*wn;
            }
        }
    }
    if(op!=-1)
        return;
    // Inverse transform: normalise by the transform length.
    for(int i=0;i<len;i++)
        y[i]/=len;
}
// Coefficient buffers for the two operands; fft() transforms them in place.
Complex x1[maxn], x2[maxn];
// Raw decimal input strings read by main().
char str1[maxn], str2[maxn];
// Rounded convolution coefficients, then decimal digits after carrying.
int sum[maxn];
/// Reads pairs of decimal integers until input ends and prints each product,
/// computed as a digit convolution via FFT.
int main()
{
    // Require both tokens. With "!= EOF" a trailing unpaired token (scanf
    // returns 1) would be multiplied by a stale str2.
    while(scanf("%s%s",str1,str2)==2)
    {
        int len1=strlen(str1);
        int len2=strlen(str2);
        // Transform size: smallest power of two >= 2*max(len1,len2). That is
        // >= len1+len2, so the cyclic convolution does not wrap around.
        int len=1;
        while(len<len1*2||len<len2*2)
            len<<=1;
        // Load digits least-significant first, zero-padded up to len.
        for(int i=0;i<len1;i++)
            x1[i]=Complex(str1[len1-1-i]-'0',0);
        for(int i=len1;i<len;i++)
            x1[i]=Complex(0,0);
        for(int i=0;i<len2;i++)
            x2[i]=Complex(str2[len2-i-1]-'0',0);
        for(int i=len2;i<len;i++)
            x2[i]=Complex(0,0);
        // Forward transforms, pointwise product, inverse transform.
        fft(x1,len,1);
        fft(x2,len,1);
        for(int i=0;i<len;i++)
            x1[i]=x1[i]*x2[i];
        fft(x1,len,-1);
        // Round each convolution coefficient to the nearest integer.
        for(int i=0;i<len;i++)
            sum[i]=(int)(x1[i].real()+0.5);
        // The carry loop writes into sum[len], which the rounding loop does
        // not assign. Clear it so it never holds data from a previous case.
        sum[len]=0;
        for(int i=0;i<len;i++)
        {
            sum[i+1]+=sum[i]/10;
            sum[i]%=10;
        }
        // The product has at most len1+len2 digits. Skip leading zeros but
        // always print at least one digit.
        int top=len1+len2-1;
        while(sum[top]<=0&&top>0)
            top--;
        for(int i=top;i>=0;i--)
            putchar(sum[i]+'0');
        putchar('\n');
    }
    return 0;
}
这个是阳哥的NTT
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
// Capacity of the global buffers and of NTT()'s static scratch array.
#define MAXN 400000
#define LL long long
// NTT modulus P = 479 * 2^21 + 1 = 1004535809. It is a prime (NOT a Fermat prime)
// and 2^21 divides P-1, so power-of-two transform lengths up to 2^21 are possible.
const LL P = (479 << 21) + 1;
// G is a primitive root modulo P: G^(P-1) == 1 (mod P) and G^i != 1 for 0 < i < P-1.
const LL G = 3;
// Alternative (modulus, primitive root) pairs noted by the original author:
//P = (119 << 23) +1;G=3
//P = (1917 << 19)+1;G=5
//P = (453 << 21 )+1 G=7
#define NUM 20 // entries in Wn: roots for lengths 2^0..2^19 (2^19 >= MAXN)
LL Wn[NUM];
using namespace std;
/// Modular exponentiation by repeated squaring.
/// @param a    base; any value, including negative or >= mod (reduced first)
/// @param n    exponent, expected >= 0 (a negative n yields 1 % mod)
/// @param mod  modulus, must satisfy 1 <= mod and (mod-1)^2 fitting in long long
/// @return a^n mod `mod`, in [0, mod)
/// `long long` here is the same type the file's LL macro expands to.
long long quickly_mod(long long a, long long n, long long mod)
{
    // Reduce the base first. Otherwise base*base overflows for large a,
    // and a negative a would give a negative result.
    long long base = a % mod;
    if (base < 0)
        base += mod;
    // 1 % mod makes mod == 1 correctly return 0.
    long long res = 1 % mod;
    while (n > 0)
    {
        if (n & 1)
            res = res * base % mod;
        base = base * base % mod;
        n >>= 1;
    }
    return res;
}
void getWn()
{
for (int i = 0; i < NUM; i++)
{
int t = 1 << i;
Wn[i] = quickly_mod(G, (P - 1) / t, P);
}
}
/// Returns the smallest power of two that is >= n (returns 1 when n <= 1).
inline int turn(int n)
{
    int pow2 = 1;
    while (pow2 < n)
        pow2 <<= 1;
    return pow2;
}
/// Recursively scatters S[cnt], S[cnt+1], ... into dst at bit-reversed
/// positions. Called as build(dst, S, n, 1, 0, cnt) with cnt == 0, it writes
/// dst[rev(k)] = S[k] for n a power of two. `m` is the current stride and
/// `curr` the current offset. cnt advances once per element written.
/// (The first parameter was renamed from the reserved identifier `_S`.)
void build(LL dst[], LL S[], int n, int m, int curr, int &cnt)
{
    if (m != n)
    {
        const int stride = m << 1;
        // Even-indexed subsequence first, then odd-indexed, offset by m.
        build(dst, S, n, stride, curr, cnt);
        build(dst, S, n, stride, curr + m, cnt);
        return;
    }
    dst[curr] = S[cnt];
    ++cnt;
}
/// In-place number-theoretic transform of S[0..n) modulo P.
/// op == 1: forward transform using the roots in Wn.
/// op == -1: inverse transform (forward transform, reverse indices 1..n-1,
///           then multiply by n^-1 mod P).
/// n is expected to be a power of two no larger than MAXN.
/// Uses a static scratch buffer, so the function is not reentrant.
void NTT(LL S[], int n, int op)
{
    // Bit-reversal permutation via build(), staged in the scratch buffer.
    static LL reordered[MAXN];
    int taken = 0;
    build(reordered, S, n, 1, 0, taken);
    memcpy(S, reordered, sizeof(LL)*n);

    // Butterfly stages. At `level`, blocks of length span = 2^level are
    // merged using the primitive 2^level-th root Wn[level].
    int level = 1;
    for (int span = 2; span <= n; span <<= 1, ++level)
    {
        const int half = span >> 1;
        const LL root = Wn[level];
        for (int base = 0; base < n; base += span)
        {
            LL w = 1;
            for (int j = 0; j < half; ++j)
            {
                const LL even = S[base + j];
                const LL odd = S[base + j + half];
                S[base + j] = (even + w*odd) % P;
                S[base + j + half] = ((even - w*odd) % P + P) % P;
                w = w*root%P;
            }
        }
    }

    if (op != -1)
        return;
    // Inverse: reversing indices 1..n-1 turns the forward transform into
    // one with inverse roots; then scale by n^-1 (Fermat's little theorem).
    for (int i = 1; i < (n >> 1); ++i)
        swap(S[i], S[n - i]);
    const LL inv_n = quickly_mod(n, P - 2, P);
    for (int i = 0; i < n; ++i)
        S[i] = S[i] * inv_n % P;
}
// Operand digit arrays (least-significant digit at index 0); NTT() transforms them in place.
LL A[MAXN], B[MAXN];
// Input buffer, reused for both operands.
char s[MAXN];
// Product coefficients, then decimal digits after carrying.
int ans[MAXN];
/// Reads pairs of decimal integers until input ends and prints each product,
/// computed as a digit convolution via NTT modulo P.
int main()
{
    getWn();
    while (scanf("%s", s) == 1)
    {
        memset(A, 0, sizeof(A));
        memset(B, 0, sizeof(B));
        // First operand, least-significant digit at index 0.
        int len = strlen(s);
        int n = len;
        for (int i = 0; i < len; i++)
            A[i] = s[len - i - 1] - '0';
        // Second operand. Stop if input ends after an unpaired token, instead
        // of multiplying by whatever the buffer still holds.
        if (scanf("%s", s) != 1)
            break;
        len = strlen(s);
        for (int i = 0; i < len; i++)
            B[i] = s[len - i - 1] - '0';
        // Transform length 2*turn(max(len1, len2)) >= len1 + len2, so the
        // cyclic convolution does not wrap around.
        n = len > n ? len : n;
        n = turn(n);
        const int total = n << 1;
        NTT(A, total, 1);
        NTT(B, total, 1);
        for (int i = 0; i < total; i++)
            A[i] = (A[i] * B[i]) % P;
        NTT(A, total, -1);
        // Each coefficient is at most 81 * 50000 < P, so the residue is the
        // exact integer value.
        memset(ans, 0, sizeof(ans));
        for (int i = 0; i < total; i++)
            ans[i] = A[i];
        for (int i = 0; i < total; i++)
        {
            ans[i + 1] += ans[i] / 10;
            ans[i] %= 10;
        }
        // Skip leading zeros, keeping at least one digit.
        len = total;
        while (len > 0 && !ans[len])
            len--;
        for (; len >= 0; len--)
            putchar(ans[len] + '0');
        putchar('\n');
    }
    return 0;
}