@RabbitHu
2017-08-30T00:28:18.000000Z
字数 3583
阅读 2336
笔记
| 名称 | 树链剖分 |
|---|---|
| 概述 | 重链+线段树 |
| 特点 | 七个数组、只处理重链 |
| 应用 | 与树上路径有关的修改、询问 |
| 定义 | 内容 |
|---|---|
| 重儿子 | 一个节点的最大的子树的根节点 |
| 轻儿子 | 不是重儿子的儿子 |
| 重边 | 连接一个节点与它的重儿子的边 |
| 轻边 | 不是重边的边 |
| 重链 | 全部由重边连成的链 |
| 重链头 | top[i]: i所在重链中深度最小的点 |
下面是一棵树,双线的是重边,单线的是轻边。
[1]||[2]|| \[3] [4]/ \\ \\[5] [6] [7]\\[8]
树链剖分后的树有如下性质:
1. 如果(u, v)是轻边,则 。
2. 根节点到一个节点的路径上,轻边个数不超过 ,重链条数也不超过。
3. 一个节点在且只在一条重链上,每条重链上的点深度递增。
将所有点放在一个区间上,要求每条重链上的点挨在一起。
如上图,可以做出这样的区间(5自身也是一条重链):
[1][2][3][6][8][5][4][7]
用线段树维护这个区间,如例题要求路径最大值和数值总和,可以用线段树维护区间最大值和区间和。
同线段树修改。用 pos[i] 记录树上的 i 号节点在区间中的位置(下标);
用 idx[i] 记录区间中的 i 号节点在树中的编号。
当修改树上 i 号节点的值时,就是在区间中修改 pos[i] 位置的值。
查询一条路径的最大值(或总和)时,将路径分为许多重边。
设我们要求(u, v)路径上的的最大值,伪代码如下:
//seg_path(u, v) 用线段树求 u, v 代表的重链区间的区间和path_sum(u, v)如果 u 和 v 不在同一条重链如果 top[u] 比 top[v] 低交换 u, v返回 path_sum(u, fa[top[v]]) + seg_sum(top[v], v);如果 u 和 v 在同一条重链如果 u 比 v 低交换 u, v返回 seg_sum(u, v);
例如(还是这棵树):
[1]||[2]|| \[3] [4]/ \\ \\[5] [6] [7]\\[8]
如果要求(6, 7)最短路,则调用 path_sum(6, 7) 时求的是 path_sum(6, 2) + seg_sum(4, 7)。
#include <cstdio>#include <cmath>#include <cstring>#include <algorithm>#include <iostream>#include <cctype>using namespace std;int read(){int sgn = 1;char ch;while(!isdigit(ch = getchar()))if(ch == '-') sgn = -1;int res = ch - '0';while(isdigit(ch = getchar()))res = res * 10 + ch - '0';return sgn * res;}const int N = 30005, INF = 0x3f3f3f3f;int n, Q, val[N];//val: 节点上的值int ecnt, adj[N], nxt[2*N], go[2*N];int tot, fa[N], dep[N], top[N], sze[N], son[N], pos[N], idx[N];//数序列中元素个数;父亲;深度;节点所在重链头部;子树大小;重儿子;//点在序列中的下标;序列中某元素在树中的编号int SUM[4*N], MAX[4*N];//线段树void add(int u, int v){go[++ecnt] = v; nxt[ecnt] = adj[u]; adj[u] = ecnt;go[++ecnt] = u; nxt[ecnt] = adj[v]; adj[v] = ecnt;}void seg_build(int k, int l, int r){if(l == r) return (void)(SUM[k] = MAX[k] = val[idx[l]]);int mid = (l + r) >> 1;seg_build(k << 1, l, mid); seg_build(k << 1 | 1, mid + 1, r);SUM[k] = SUM[k << 1] + SUM[k << 1 | 1];MAX[k] = max(MAX[k << 1], MAX[k << 1 | 1]);}void seg_change(int k, int l, int r, int p, int x){//将序列中第p个元素权值修改为xif(l == r) return (void)(SUM[k] = MAX[k] = x);int mid = (l + r) >> 1;if(p <= mid) seg_change(k << 1, l, mid, p, x);else seg_change(k << 1 | 1, mid + 1, r, p, x);SUM[k] = SUM[k << 1] + SUM[k << 1 | 1];MAX[k] = max(MAX[k << 1], MAX[k << 1 | 1]);}int seg_sum(int k, int l, int r, int ql, int qr){if(l >= ql && r <= qr) return SUM[k];int mid = (l + r) >> 1, res = 0;if(ql <= mid) res += seg_sum(k << 1, l, mid, ql, qr);if(qr > mid) res += seg_sum(k << 1 | 1, mid + 1, r, ql, qr);return res;}int seg_max(int k, int l, int r, int ql, int qr){if(l >= ql && r <= qr) return MAX[k];int mid = (l + r) >> 1, res = -INF;if(ql <= mid) res = max(res, seg_max(k << 1, l, mid, ql, qr));if(qr > mid) res = max(res, seg_max(k << 1 | 1, mid + 1, r, ql, qr));return res;}void init(){int q[n], r, u, v;dep[1] = 1, q[r = 0] = 1;//共三次BFS:1下去、2上来、3下去//1:正序处理dep和fafor(int l = 0; l <= r; l++){sze[u = q[l]] = 1;for(int e = adj[u]; e; e = nxt[e]){if(dep[v = go[e]]) continue;dep[v] = dep[u] + 1, fa[v] = u;q[++r] = v;}}//2:倒序遍历队列处理sze和sonfor(int l = r; l >= 0; l--){sze[u = fa[q[l]]] += sze[v = q[l]];if(sze[v] > sze[son[u]]) son[u] = v;}//3:正序处理序列(idx, pos, top)for(int l = 0; l <= r; l++){if(top[u = q[l]]) continue;for(int v = u; v; v = son[v])idx[pos[v] = ++tot] = v, top[v] = u;}seg_build(1, 1, n);}int path_sum(int u, int v){if(top[u] != top[v]){if(dep[top[u]] > dep[top[v]]) swap(u, v); //使top[u]深度较小return seg_sum(1, 1, n, pos[top[v]], pos[v]) + path_sum(u, fa[top[v]]);}if(dep[u] > dep[v]) swap(u, v); //使u深度较小return seg_sum(1, 1, n, pos[u], pos[v]);}int path_max(int u, int v){if(top[u] != top[v]){if(dep[top[u]] > dep[top[v]]) swap(u, v); //使top[u]深度较小return max(seg_max(1, 1, n, pos[top[v]], pos[v]), path_max(u, fa[top[v]]));}if(dep[u] > dep[v]) swap(u, v); //使u深度较小return seg_max(1, 1, n, pos[u], pos[v]);}int main(){n = read();for(int i = 1; i < n; i++)add(read(), read());for(int i = 1; i <= n; i++)val[i] = read();init();Q = read();int u, v;char op;while(Q--){while((op = getchar()) != 'C' && op != 'Q');if(op == 'Q') op = getchar();u = read(), v = read();if(op == 'C') seg_change(1, 1, n, pos[u], v);else if(op == 'S') printf("%d\n", path_sum(u, v));else if(op == 'M') printf("%d\n", path_max(u, v));}return 0;}