[关闭]
@RabbitHu 2017-08-30T08:28:18.000000Z 字数 3583 阅读 2024

树链剖分——你为什么这么树链呢?

笔记


名称 树链剖分
概述 重链+线段树
特点 七个数组、只处理重链
应用 与树上路径有关的修改、询问

例题 BZOJ 1036

讲解

与树剖有关的定义

定义 内容
重儿子 一个节点的最大的子树的根节点
轻儿子 不是重儿子的儿子
重边 连接一个节点与它的重儿子的边
轻边 不是重边的边
重链 全部由重边连成的链
重链头 top[i]: i所在重链中深度最小的点

下面是一棵树,双线的是重边,单线的是轻边。

  1. [1]
  2. ||
  3. [2]
  4. || \
  5. [3] [4]
  6. / \\ \\
  7. [5] [6] [7]
  8. \\
  9. [8]

树链剖分后的树有如下性质:
1. 如果(u, v)是轻边,则
2. 根节点到一个节点的路径上,轻边个数不超过 ,重链条数也不超过
3. 一个节点在且只在一条重链上,每条重链上的点深度递增。

与树剖有关的操作

1. 预处理

将所有点放在一个区间上,要求每条重链上的点挨在一起。
如上图,可以做出这样的区间(5自身也是一条重链):

  1. [1][2][3][6][8][5][4][7]

用线段树维护这个区间,如例题要求路径最大值和数值总和,可以用线段树维护区间最大值和区间和。

2. 修改

同线段树修改。用 pos[i] 记录树上的 i 号节点在区间中的位置(下标);
用 idx[i] 记录区间中的 i 号节点在树中的编号。
当修改树上 i 号节点的值时,就是在区间中修改 pos[i] 位置的值。

3. 查询

查询一条路径的最大值(或总和)时,将路径分为许多重边。

设我们要求(u, v)路径上的的最大值,伪代码如下:

  1. //seg_path(u, v) 用线段树求 u, v 代表的重链区间的区间和
  2. path_sum(u, v)
  3. 如果 u v 不在同一条重链
  4. 如果 top[u] top[v]
  5. 交换 u, v
  6. 返回 path_sum(u, fa[top[v]]) + seg_sum(top[v], v);
  7. 如果 u v 在同一条重链
  8. 如果 u v
  9. 交换 u, v
  10. 返回 seg_sum(u, v);

例如(还是这棵树):

  1. [1]
  2. ||
  3. [2]
  4. || \
  5. [3] [4]
  6. / \\ \\
  7. [5] [6] [7]
  8. \\
  9. [8]

如果要求(6, 7)最短路,则调用 path_sum(6, 7) 时求的是 path_sum(6, 2) + seg_sum(4, 7)。

代码

  1. #include <cstdio>
  2. #include <cmath>
  3. #include <cstring>
  4. #include <algorithm>
  5. #include <iostream>
  6. #include <cctype>
  7. using namespace std;
  8. int read(){
  9. int sgn = 1;
  10. char ch;
  11. while(!isdigit(ch = getchar()))
  12. if(ch == '-') sgn = -1;
  13. int res = ch - '0';
  14. while(isdigit(ch = getchar()))
  15. res = res * 10 + ch - '0';
  16. return sgn * res;
  17. }
  18. const int N = 30005, INF = 0x3f3f3f3f;
  19. int n, Q, val[N];//val: 节点上的值
  20. int ecnt, adj[N], nxt[2*N], go[2*N];
  21. int tot, fa[N], dep[N], top[N], sze[N], son[N], pos[N], idx[N];
  22. //数序列中元素个数;父亲;深度;节点所在重链头部;子树大小;重儿子;
  23. //点在序列中的下标;序列中某元素在树中的编号
  24. int SUM[4*N], MAX[4*N];//线段树
  25. void add(int u, int v){
  26. go[++ecnt] = v; nxt[ecnt] = adj[u]; adj[u] = ecnt;
  27. go[++ecnt] = u; nxt[ecnt] = adj[v]; adj[v] = ecnt;
  28. }
  29. void seg_build(int k, int l, int r){
  30. if(l == r) return (void)(SUM[k] = MAX[k] = val[idx[l]]);
  31. int mid = (l + r) >> 1;
  32. seg_build(k << 1, l, mid); seg_build(k << 1 | 1, mid + 1, r);
  33. SUM[k] = SUM[k << 1] + SUM[k << 1 | 1];
  34. MAX[k] = max(MAX[k << 1], MAX[k << 1 | 1]);
  35. }
  36. void seg_change(int k, int l, int r, int p, int x){
  37. //将序列中第p个元素权值修改为x
  38. if(l == r) return (void)(SUM[k] = MAX[k] = x);
  39. int mid = (l + r) >> 1;
  40. if(p <= mid) seg_change(k << 1, l, mid, p, x);
  41. else seg_change(k << 1 | 1, mid + 1, r, p, x);
  42. SUM[k] = SUM[k << 1] + SUM[k << 1 | 1];
  43. MAX[k] = max(MAX[k << 1], MAX[k << 1 | 1]);
  44. }
  45. int seg_sum(int k, int l, int r, int ql, int qr){
  46. if(l >= ql && r <= qr) return SUM[k];
  47. int mid = (l + r) >> 1, res = 0;
  48. if(ql <= mid) res += seg_sum(k << 1, l, mid, ql, qr);
  49. if(qr > mid) res += seg_sum(k << 1 | 1, mid + 1, r, ql, qr);
  50. return res;
  51. }
  52. int seg_max(int k, int l, int r, int ql, int qr){
  53. if(l >= ql && r <= qr) return MAX[k];
  54. int mid = (l + r) >> 1, res = -INF;
  55. if(ql <= mid) res = max(res, seg_max(k << 1, l, mid, ql, qr));
  56. if(qr > mid) res = max(res, seg_max(k << 1 | 1, mid + 1, r, ql, qr));
  57. return res;
  58. }
  59. void init(){
  60. int q[n], r, u, v;
  61. dep[1] = 1, q[r = 0] = 1;
  62. //共三次BFS:1下去、2上来、3下去
  63. //1:正序处理dep和fa
  64. for(int l = 0; l <= r; l++){
  65. sze[u = q[l]] = 1;
  66. for(int e = adj[u]; e; e = nxt[e]){
  67. if(dep[v = go[e]]) continue;
  68. dep[v] = dep[u] + 1, fa[v] = u;
  69. q[++r] = v;
  70. }
  71. }
  72. //2:倒序遍历队列处理sze和son
  73. for(int l = r; l >= 0; l--){
  74. sze[u = fa[q[l]]] += sze[v = q[l]];
  75. if(sze[v] > sze[son[u]]) son[u] = v;
  76. }
  77. //3:正序处理序列(idx, pos, top)
  78. for(int l = 0; l <= r; l++){
  79. if(top[u = q[l]]) continue;
  80. for(int v = u; v; v = son[v])
  81. idx[pos[v] = ++tot] = v, top[v] = u;
  82. }
  83. seg_build(1, 1, n);
  84. }
  85. int path_sum(int u, int v){
  86. if(top[u] != top[v]){
  87. if(dep[top[u]] > dep[top[v]]) swap(u, v); //使top[u]深度较小
  88. return seg_sum(1, 1, n, pos[top[v]], pos[v]) + path_sum(u, fa[top[v]]);
  89. }
  90. if(dep[u] > dep[v]) swap(u, v); //使u深度较小
  91. return seg_sum(1, 1, n, pos[u], pos[v]);
  92. }
  93. int path_max(int u, int v){
  94. if(top[u] != top[v]){
  95. if(dep[top[u]] > dep[top[v]]) swap(u, v); //使top[u]深度较小
  96. return max(seg_max(1, 1, n, pos[top[v]], pos[v]), path_max(u, fa[top[v]]));
  97. }
  98. if(dep[u] > dep[v]) swap(u, v); //使u深度较小
  99. return seg_max(1, 1, n, pos[u], pos[v]);
  100. }
  101. int main(){
  102. n = read();
  103. for(int i = 1; i < n; i++)
  104. add(read(), read());
  105. for(int i = 1; i <= n; i++)
  106. val[i] = read();
  107. init();
  108. Q = read();
  109. int u, v;
  110. char op;
  111. while(Q--){
  112. while((op = getchar()) != 'C' && op != 'Q');
  113. if(op == 'Q') op = getchar();
  114. u = read(), v = read();
  115. if(op == 'C') seg_change(1, 1, n, pos[u], v);
  116. else if(op == 'S') printf("%d\n", path_sum(u, v));
  117. else if(op == 'M') printf("%d\n", path_max(u, v));
  118. }
  119. return 0;
  120. }
添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注