@RabbitHu
2017-08-30T08:28:18.000000Z
字数 3583
阅读 2024
笔记
名称 | 树链剖分 |
---|---|
概述 | 重链+线段树 |
特点 | 七个数组、只处理重链 |
应用 | 与树上路径有关的修改、询问 |
定义 | 内容 |
---|---|
重儿子 | 一个节点的最大的子树的根节点 |
轻儿子 | 不是重儿子的儿子 |
重边 | 连接一个节点与它的重儿子的边 |
轻边 | 不是重边的边 |
重链 | 全部由重边连成的链 |
重链头 | top[i]: i所在重链中深度最小的点 |
下面是一棵树,双线的是重边,单线的是轻边。
[1]
||
[2]
|| \
[3] [4]
/ \\ \\
[5] [6] [7]
\\
[8]
树链剖分后的树有如下性质:
1. 如果(u, v)是轻边,则 。
2. 根节点到一个节点的路径上,轻边个数不超过 ,重链条数也不超过。
3. 一个节点在且只在一条重链上,每条重链上的点深度递增。
将所有点放在一个区间上,要求每条重链上的点挨在一起。
如上图,可以做出这样的区间(5自身也是一条重链):
[1][2][3][6][8][5][4][7]
用线段树维护这个区间,如例题要求路径最大值和数值总和,可以用线段树维护区间最大值和区间和。
同线段树修改。用 pos[i] 记录树上的 i 号节点在区间中的位置(下标);
用 idx[i] 记录区间中的 i 号节点在树中的编号。
当修改树上 i 号节点的值时,就是在区间中修改 pos[i] 位置的值。
查询一条路径的最大值(或总和)时,将路径分为许多重边。
设我们要求(u, v)路径上的的最大值,伪代码如下:
//seg_path(u, v) 用线段树求 u, v 代表的重链区间的区间和
path_sum(u, v)
如果 u 和 v 不在同一条重链
如果 top[u] 比 top[v] 低
交换 u, v
返回 path_sum(u, fa[top[v]]) + seg_sum(top[v], v);
如果 u 和 v 在同一条重链
如果 u 比 v 低
交换 u, v
返回 seg_sum(u, v);
例如(还是这棵树):
[1]
||
[2]
|| \
[3] [4]
/ \\ \\
[5] [6] [7]
\\
[8]
如果要求(6, 7)最短路,则调用 path_sum(6, 7) 时求的是 path_sum(6, 2) + seg_sum(4, 7)。
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <cctype>
using namespace std;
int read(){
int sgn = 1;
char ch;
while(!isdigit(ch = getchar()))
if(ch == '-') sgn = -1;
int res = ch - '0';
while(isdigit(ch = getchar()))
res = res * 10 + ch - '0';
return sgn * res;
}
const int N = 30005, INF = 0x3f3f3f3f;
int n, Q, val[N];//val: 节点上的值
int ecnt, adj[N], nxt[2*N], go[2*N];
int tot, fa[N], dep[N], top[N], sze[N], son[N], pos[N], idx[N];
//数序列中元素个数;父亲;深度;节点所在重链头部;子树大小;重儿子;
//点在序列中的下标;序列中某元素在树中的编号
int SUM[4*N], MAX[4*N];//线段树
void add(int u, int v){
go[++ecnt] = v; nxt[ecnt] = adj[u]; adj[u] = ecnt;
go[++ecnt] = u; nxt[ecnt] = adj[v]; adj[v] = ecnt;
}
void seg_build(int k, int l, int r){
if(l == r) return (void)(SUM[k] = MAX[k] = val[idx[l]]);
int mid = (l + r) >> 1;
seg_build(k << 1, l, mid); seg_build(k << 1 | 1, mid + 1, r);
SUM[k] = SUM[k << 1] + SUM[k << 1 | 1];
MAX[k] = max(MAX[k << 1], MAX[k << 1 | 1]);
}
void seg_change(int k, int l, int r, int p, int x){
//将序列中第p个元素权值修改为x
if(l == r) return (void)(SUM[k] = MAX[k] = x);
int mid = (l + r) >> 1;
if(p <= mid) seg_change(k << 1, l, mid, p, x);
else seg_change(k << 1 | 1, mid + 1, r, p, x);
SUM[k] = SUM[k << 1] + SUM[k << 1 | 1];
MAX[k] = max(MAX[k << 1], MAX[k << 1 | 1]);
}
int seg_sum(int k, int l, int r, int ql, int qr){
if(l >= ql && r <= qr) return SUM[k];
int mid = (l + r) >> 1, res = 0;
if(ql <= mid) res += seg_sum(k << 1, l, mid, ql, qr);
if(qr > mid) res += seg_sum(k << 1 | 1, mid + 1, r, ql, qr);
return res;
}
int seg_max(int k, int l, int r, int ql, int qr){
if(l >= ql && r <= qr) return MAX[k];
int mid = (l + r) >> 1, res = -INF;
if(ql <= mid) res = max(res, seg_max(k << 1, l, mid, ql, qr));
if(qr > mid) res = max(res, seg_max(k << 1 | 1, mid + 1, r, ql, qr));
return res;
}
void init(){
int q[n], r, u, v;
dep[1] = 1, q[r = 0] = 1;
//共三次BFS:1下去、2上来、3下去
//1:正序处理dep和fa
for(int l = 0; l <= r; l++){
sze[u = q[l]] = 1;
for(int e = adj[u]; e; e = nxt[e]){
if(dep[v = go[e]]) continue;
dep[v] = dep[u] + 1, fa[v] = u;
q[++r] = v;
}
}
//2:倒序遍历队列处理sze和son
for(int l = r; l >= 0; l--){
sze[u = fa[q[l]]] += sze[v = q[l]];
if(sze[v] > sze[son[u]]) son[u] = v;
}
//3:正序处理序列(idx, pos, top)
for(int l = 0; l <= r; l++){
if(top[u = q[l]]) continue;
for(int v = u; v; v = son[v])
idx[pos[v] = ++tot] = v, top[v] = u;
}
seg_build(1, 1, n);
}
int path_sum(int u, int v){
if(top[u] != top[v]){
if(dep[top[u]] > dep[top[v]]) swap(u, v); //使top[u]深度较小
return seg_sum(1, 1, n, pos[top[v]], pos[v]) + path_sum(u, fa[top[v]]);
}
if(dep[u] > dep[v]) swap(u, v); //使u深度较小
return seg_sum(1, 1, n, pos[u], pos[v]);
}
int path_max(int u, int v){
if(top[u] != top[v]){
if(dep[top[u]] > dep[top[v]]) swap(u, v); //使top[u]深度较小
return max(seg_max(1, 1, n, pos[top[v]], pos[v]), path_max(u, fa[top[v]]));
}
if(dep[u] > dep[v]) swap(u, v); //使u深度较小
return seg_max(1, 1, n, pos[u], pos[v]);
}
int main(){
n = read();
for(int i = 1; i < n; i++)
add(read(), read());
for(int i = 1; i <= n; i++)
val[i] = read();
init();
Q = read();
int u, v;
char op;
while(Q--){
while((op = getchar()) != 'C' && op != 'Q');
if(op == 'Q') op = getchar();
u = read(), v = read();
if(op == 'C') seg_change(1, 1, n, pos[u], v);
else if(op == 'S') printf("%d\n", path_sum(u, v));
else if(op == 'M') printf("%d\n", path_max(u, v));
}
return 0;
}