@Chilling
2017-02-18T11:18:13.000000Z
字数 3881
阅读 1015
树链剖分
线段树
Description
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作:
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身
Input
输入的第一行为一个整数n,表示节点的个数。接下来n-1行,每行2个整数a和b,表示节点a和节点b之间有一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作的总数。
接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
Output
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
Sample Input
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
Sample Output
4
1
2
2
10
6
5
6
5
16
部分摘自starszys的博客,但是这个博客上面的题是边权,因此以下内容有一定的修改……
树链,就是树上的路径。剖分,就是把路径分类为重链和轻链。
重儿子:siz[u]为v的子节点中siz值最大的,那么u就是v的重儿子。
轻儿子:v的其它子节点。
重边:点v与其重儿子的连边。
轻边:点v与其轻儿子的连边。
重链:由重边连成的路径。
轻链:轻边。
剖分后的树有如下性质:
性质1:如果(v,u)为轻边,则siz[u] * 2 < siz[v];
性质2:从根到某一点的路径上轻链、重链的个数都不大于logn。
分析:对于这道题来说,我们应该求出如下内容:
siz[u]
表示以u为根的子树的节点数;
dep[u]
表示u的深度(根深度为1);
fa[u]
表示u的父亲,son[u]
表示与u在同一重链上的u的儿子节点(重儿子);
top[u]
表示u所在的重链的顶端节点;
pre[u]
表示u的新编号。dfs1:求出siz[u],dep[u],fa[u],son[u];
dfs2:对于某点u,若它不是叶子节点,显然有top[son[u]] = top[u];节点u重新编号为pre[u]之后,valnow[pre[u]]=val[u];此时,为了使一条重链各边在线段树中连续分布,应当进行dfs2(son[u],u,id)。(id为重链顶端节点)
而对于u个各个轻儿子,top[v]=v。询问:
- 询问(u,v)路径上的最值,首先需要判断u和v是否在同一条重链上,若在同一条重链上,使u深v浅,查询(pre[v],pre[u])区间内的最值即可;
- 若u,v不在同一条重链上,那么判断u和v重链顶端的深度,使top[u]更深,先查询pre[top[u]]到pre[u],那么u点所在的这部分重链就已经查询完毕了;接着u等于fa[top[u]],继续以上操作,直到u和v在同一条重链上。
- 查询部分即线段树。
#include<stdio.h>
#include<vector>
#include<algorithm>
#include<string.h>
#define INF 0x3f3f3f3f
const int maxn=30005;
using namespace std;
int n;
int val[maxn];
int siz[maxn],dep[maxn],fa[maxn],son[maxn];//dfs1
int pre[maxn];//对于节点v,pre[v]对应新编号++clk
int valnow[maxn];//节点新编号后对应的值 valnow[clk]=val[v]
int top[maxn];//v节点所在重链顶端的点
int clk;
struct node
{
int l,r,maxs,sum;
}s[maxn*4];
vector<int> V[maxn];
void dfs1(int u,int f,int d)//维护fa,dep,siz,son
{
dep[u]=d; //深度
fa[u]=f; //u的父亲
siz[u]=1; //以u为根的子树节点
//son[u]=0; //u的重儿子
for(int i=0;i<V[u].size();i++)
{
int v=V[u][i];
if(v==f) continue;
dfs1(v,u,d+1);
if(siz[v]>siz[son[u]]) //更新重儿子
son[u]=v;
siz[u]+=siz[v];
}
}
void dfs2(int u,int f,int id) //id为u所在的顶端节点
{
top[u]=id;
pre[u]=++clk;
valnow[clk]=val[u];
if(!son[u]) return; //叶子节点,无重儿子
dfs2(son[u],u,id);
for(int i=0;i<V[u].size();i++)
{
int v=V[u][i];
if(v==f||v==son[u]) continue;
dfs2(v,u,v);
}
}
void build(int id,int l,int r)
{
s[id].l=l;
s[id].r=r;
if(l==r)
s[id].maxs=s[id].sum=valnow[l];
else
{
int mid=(l+r)/2;
build(id*2,l,mid);
build(id*2+1,mid+1,r);
s[id].maxs=max(s[id*2].maxs,s[id*2+1].maxs);
s[id].sum=s[id*2].sum+s[id*2+1].sum;
}
}
void rep(int id,int x,int num)
{
if(s[id].l==s[id].r)
s[id].maxs=s[id].sum=num;
else
{
int mid=(s[id].l+s[id].r)/2;
if(x<=mid)
rep(id*2,x,num);
else
rep(id*2+1,x,num);
s[id].maxs=max(s[id*2].maxs,s[id*2+1].maxs);
s[id].sum=s[id*2].sum+s[id*2+1].sum;
}
}
int qmax(int id,int l,int r)
{
if(l<=s[id].l&&s[id].r<=r)
return s[id].maxs;
int mid=(s[id].l+s[id].r)/2;
if(r<=mid)
return qmax(id*2,l,r);
else if(l>mid)
return qmax(id*2+1,l,r);
else
return max(qmax(id*2,l,r),qmax(id*2+1,l,r));
}
int qsum(int id,int l,int r)
{
if(l<=s[id].l&&s[id].r<=r)
return s[id].sum;
int mid=(s[id].l+s[id].r)/2;
if(r<=mid)
return qsum(id*2,l,r);
else if(l>mid)
return qsum(id*2+1,l,r);
else
return qsum(id*2,l,r)+qsum(id*2+1,l,r);
}
int ansmax(int u,int v)
{
int ans=-INF;
while(top[u]!=top[v])//不在一条重链上
{
if(dep[top[u]]<dep[top[v]])
swap(u,v); //u重链顶端更深
ans=max(ans,qmax(1,pre[top[u]],pre[u]));
u=fa[top[u]];
}
if(dep[u]<dep[v])
swap(u,v);
ans=max(ans,qmax(1,pre[v],pre[u]));
return ans;
}
int anssum(int u,int v)
{
int ans=0;
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]])
swap(u,v); //u重链顶端更深
ans+=qsum(1,pre[top[u]],pre[u]);
u=fa[top[u]];
}
if(dep[u]<dep[v])
swap(u,v);
ans+=qsum(1,pre[v],pre[u]);
return ans;
}
int main()
{
int x,y,q;
char ch[9];
scanf("%d",&n);
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
V[x].push_back(y);
V[y].push_back(x);
}
for(int i=1;i<=n;i++)
scanf("%d",&val[i]);
dfs1(1,0,1);
dfs2(1,0,1);
build(1,1,n);
scanf("%d",&q);
while(q--)
{
scanf("%s%d%d",ch,&x,&y);
if(ch[3]=='N')
rep(1,pre[x],y);
else if(ch[3]=='X')
printf("%d\n",ansmax(x,y));
else
printf("%d\n",anssum(x,y));
}
return 0;
}