@xiaoziyao
2020-11-24T20:20:22.000000Z
字数 6457
阅读 1041
解题报告
CF809E Surprise me!解题报告:
虽然交不了题解了,但是还是记录一下这道大毒瘤题。
由于不太好求,我们考虑是一个排列,因此我们先设,即为每个值的位置,然后带进去:
我们首先根据欧拉函数的一个公式变一下式子(证明在题解末尾),并对于这个式子套路地反演一下(先忽略掉前面的):
套路地设,那么有:
对于前面的东西,我们设,这个东西很显然可以先线性筛出来和,然后暴力枚举倍数做到求出。
然后,我们把和带回原来的式子:
对于后面的东西,我们发现对于每个,用到的点只有个,因此,我们用到的总点数为个。
如果设为第次的总点数,我们有,这个数据范围启发我们来建立一颗虚树。
对于每个,我们暴力枚举它的所有倍数的点,建立一颗虚树,然后在上面跑树形。
具体地,我们设,为子树中的全体贡献,即,并设两个辅助数组。
考虑合并两颗子树和(是的儿子),那么答案为子树的贡献加子树的贡献加与两两之间的贡献,即。
把拆开:
的转移和的转移简单一些:
最后的答案即为。
由于建立虚树是的,所以总复杂度为
#include<stdio.h>
#include<vector>
#include<algorithm>
using namespace std;
const int maxn=200005,maxm=maxn<<1,mod=1000000007,maxk=25;
int n,e,cnt,ans,top,stp;
int a[maxn],start[maxn],to[maxm],then[maxm],f[maxn],P[maxn],p[maxn],pos[maxn],dfn[maxn],dep[maxn],fore[maxn][maxk],dp1[maxn],dp2[maxn],dp3[maxn],ok[maxn],stk[maxn],miu[maxn],phi[maxn],nphi[maxn];
vector<int>vp[maxn],g[maxn];
inline void add(int x,int y){
then[++e]=start[x],start[x]=e,to[e]=y;
}
inline void newadd(int x,int y){
g[x].push_back(y);
}
inline int cmp(int a,int b){
return dfn[a]<dfn[b];
}
void dfs(int x,int last){
dfn[x]=++cnt,dep[x]=dep[last]+1,fore[x][0]=last;
for(int i=1;i<=20;i++)
fore[x][i]=fore[fore[x][i-1]][i-1];
for(int i=start[x];i;i=then[i]){
int y=to[i];
if(y==last)
continue;
dfs(y,x);
}
}
int lca(int a,int b){
if(dep[a]<dep[b])
swap(a,b);
for(int i=20;i>=0;i--)
if(dep[fore[a][i]]>=dep[b])
a=fore[a][i];
if(a==b)
return a;
for(int i=20;i>=0;i--)
if(fore[a][i]!=fore[b][i])
a=fore[a][i],b=fore[b][i];
return fore[a][0];
}
void DP(int x,int k){
dp1[x]=dp2[x]=0,dp3[x]=a[x]%k==0? phi[a[x]]:0;
for(int i=0;i<g[x].size();i++){
int y=g[x][i],z=(dep[y]-dep[x]+mod)%mod;
DP(y,k);
dp1[x]+=((dp1[y]+1ll*(dp2[x]+1ll*dp3[x]*z%mod)%mod*dp3[y])%mod+1ll*dp2[y]*dp3[x]%mod)%mod;
dp2[x]+=(dp2[y]+1ll*dp3[y]*z%mod)%mod,dp3[x]+=dp3[y];
dp1[x]%=mod,dp2[x]%=mod,dp3[x]%=mod;
}
}
int calc(int k){
top=0,stp++;
for(int i=0;i<vp[k].size();i++)
ok[vp[k][i]]=stp;
sort(vp[k].begin(),vp[k].end(),cmp);
g[1].clear(),stk[++top]=1;
for(int i=0;i<vp[k].size();i++){
if(vp[k][i]==1)
continue;
int l=lca(vp[k][i],stk[top]);
if(l!=stk[top]){
while(top>0&&dfn[l]<dfn[stk[top-1]])
newadd(stk[top-1],stk[top]),top--;
if(dfn[l]>dfn[stk[top-1]])
g[l].clear(),newadd(l,stk[top]),top--,stk[++top]=l;
else newadd(stk[top-1],stk[top]),top--;
}
g[vp[k][i]].clear(),stk[++top]=vp[k][i];
}
for(int j=1;j<top;j++)
newadd(stk[j],stk[j+1]);
DP(1,k);
return dp1[1];
}
int ksm(int a,int b){
int res=1;
while(b){
if(b&1)
res=1ll*res*a%mod;
a=1ll*a*a%mod,b>>=1;
}
return res;
}
void sieve(){
p[1]=phi[1]=miu[1]=nphi[1]=1;
for(int i=2;i<=n;i++){
if(p[i]==0)
P[++cnt]=i,phi[i]=i-1,miu[i]=-1;
for(int j=1;j<=cnt;j++){
if(i*P[j]>n)
break;
p[i*P[j]]=1;
if(i%P[j]==0){
phi[i*P[j]]=phi[i]*P[j];
miu[i*P[j]]=0;
break;
}
phi[i*P[j]]=phi[i]*(P[j]-1);
miu[i*P[j]]=-miu[i];
}
nphi[i]=ksm(phi[i],mod-2);
}
}
int main(){
scanf("%d",&n);
sieve();
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
pos[a[i]]=i;
}
for(int i=1;i<n;i++){
int x,y;
scanf("%d%d",&x,&y);
add(x,y),add(y,x);
}
dfs(1,0);
for(int i=1;i<=n;i++)
for(int j=i;j<=n;j+=i){
f[j]=(f[j]+1ll*i*miu[j/i]%mod*nphi[i]%mod)%mod;
f[j]=(f[j]+mod)%mod;
}
for(int i=1;i<=n;i++)
for(int j=i;j<=n;j+=i)
vp[i].push_back(pos[j]);
for(int i=1;i<=n;i++)
ans+=2ll*f[i]*calc(i)%mod,ans%=mod;
printf("%d\n",(int)(1ll*ans*ksm(n,mod-2)%mod*ksm(n-1,mod-2)%mod));
return 0;
}
证明: