【模板】树链剖分 (March 1st, 2019)
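算法
重链剖分（树链剖分）把树上的路径操作（沿重链顶端跳跃）和子树操作（DFS 序区间 [tid[x], tid[x] + size[x] - 1]）都转化为 DFS 序上的区间操作，再用带懒标记的线段树维护区间加与区间求和，结果对 mod 取模。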
题目来源
代码
//La Pluie
// Heavy-light decomposition over a rooted tree, backed by a lazy segment tree
// supporting range add / range sum modulo `mod`.
//
// Input: n q root mod, then n node values, then n-1 undirected edges, then q ops:
//   1 x y z : add z to every node on the path x..y
//   2 x y   : print the sum of node values on the path x..y (mod `mod`)
//   3 x z   : add z to every node in the subtree of x
//   4 x     : print the sum of node values in the subtree of x (mod `mod`)
//
// All sums, lazy tags and tag*length products are kept in long long and reduced
// modulo `mod` after every operation. With `mod` near 2^31 and segments of up to
// 1e5 nodes, both `tag * length` and `a + b` overflow a 32-bit int.
#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>

const int N = 1e5 + 5;

// Adjacency list stored as linked lists threaded through `e`; head[u] == -1 ends a list.
struct edge { int v, next; } e[N << 1];

// Segment-tree node covering DFS positions [l, r]: s = segment sum, lazy = pending add.
struct tree { int l, r; long long s, lazy; } t[N << 2];

int n, q, k = 1, root, pos;
long long mod;
long long val[N];
// sz (was `size`): renamed because `size` clashes with std::size under C++17.
int dep[N], ftr[N], sz[N], son[N], top[N], tid[N], rak[N], head[N];

// Reduce x into [0, mod), also for negative x.
long long norm(long long x) {
    x %= mod;
    return x < 0 ? x + mod : x;
}

// Append directed edge u -> v.
void adde(int u, int v) {
    e[k] = (edge){v, head[u]};
    head[u] = k++;
}

// Push node u's pending add to both children; m is the length of u's segment.
// Left child covers ceil(m/2) positions, right child floor(m/2), matching build().
void pushdown(int u, int m) {
    if (!t[u].lazy) return;
    const long long tag = t[u].lazy;
    const int lc = u << 1, rc = u << 1 | 1;
    t[lc].s = (t[lc].s + tag * (m - (m >> 1))) % mod;
    t[lc].lazy = (t[lc].lazy + tag) % mod;
    t[rc].s = (t[rc].s + tag * (m >> 1)) % mod;
    t[rc].lazy = (t[rc].lazy + tag) % mod;
    t[u].lazy = 0;
}

// First DFS: depth, parent, subtree size and heavy child (son[u] == -1 for leaves).
void dfs(int u, int fa) {
    dep[u] = dep[fa] + 1;
    ftr[u] = fa;
    sz[u] = 1;
    for (int i = head[u]; i != -1; i = e[i].next) {
        const int v = e[i].v;
        if (v == fa) continue;
        dfs(v, u);
        sz[u] += sz[v];
        if (son[u] == -1 || sz[v] > sz[son[u]]) son[u] = v;
    }
}

// Second DFS: assign DFS positions visiting the heavy child first, so every heavy
// chain and every subtree occupies a contiguous range. rt is the current chain top.
void getpos(int u, int rt) {
    top[u] = rt;
    tid[u] = ++pos;
    rak[tid[u]] = u;
    if (son[u] == -1) return;
    getpos(son[u], rt);
    for (int i = head[u]; i != -1; i = e[i].next) {
        const int v = e[i].v;
        if (v != son[u] && v != ftr[u]) getpos(v, v);
    }
}

void pushup(int u) {
    t[u].s = (t[u << 1].s + t[u << 1 | 1].s) % mod;
}

// Build the segment tree over DFS positions [l, r]; leaf at position l holds the
// value of the node rak[l].
void build(int u, int l, int r) {
    t[u].l = l;
    t[u].r = r;
    t[u].s = t[u].lazy = 0;
    if (l == r) {
        t[u].s = val[rak[l]];
        return;
    }
    const int mid = (l + r) >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

// Add del (expected in [0, mod)) to every position in [l, r].
void update(int u, int l, int r, long long del) {
    if (l <= t[u].l && t[u].r <= r) {
        // Reduce immediately: without this, s and lazy grow unboundedly.
        t[u].s = (t[u].s + del * (t[u].r - t[u].l + 1)) % mod;
        t[u].lazy = (t[u].lazy + del) % mod;
        return;
    }
    pushdown(u, t[u].r - t[u].l + 1);
    const int mid = (t[u].l + t[u].r) >> 1;
    if (l <= mid) update(u << 1, l, r, del);
    if (r > mid) update(u << 1 | 1, l, r, del);
    pushup(u);
}

// Sum of positions [l, r] modulo mod.
long long query(int u, int l, int r) {
    if (l <= t[u].l && t[u].r <= r) return t[u].s;
    pushdown(u, t[u].r - t[u].l + 1);
    long long res = 0;
    const int mid = (t[u].l + t[u].r) >> 1;
    if (l <= mid) res += query(u << 1, l, r);
    if (r > mid) res += query(u << 1 | 1, l, r);
    return res % mod;
}

// Add z to every node on the tree path x..y by jumping up heavy chains.
void uplca(int x, int y, long long z) {
    while (top[x] != top[y]) {
        // Always lift the endpoint whose chain top is deeper.
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        update(1, tid[top[x]], tid[x], z);
        x = ftr[top[x]];
    }
    if (dep[x] > dep[y]) std::swap(x, y);
    update(1, tid[x], tid[y], z);
}

// Sum of node values on the tree path x..y, modulo mod.
long long qulca(int x, int y) {
    long long res = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        res = (res + query(1, tid[top[x]], tid[x])) % mod;
        x = ftr[top[x]];
    }
    if (dep[x] > dep[y]) std::swap(x, y);
    return (res + query(1, tid[x], tid[y])) % mod;
}

int main() {
    std::memset(head, -1, sizeof(head));
    std::memset(son, -1, sizeof(son));
    if (std::scanf("%d %d %d %lld", &n, &q, &root, &mod) != 4) return 0;
    for (int i = 1; i <= n; ++i) {
        std::scanf("%lld", &val[i]);
        val[i] = norm(val[i]);
    }
    for (int i = 1; i < n; ++i) {
        int u, v;
        std::scanf("%d %d", &u, &v);
        adde(u, v);
        adde(v, u);
    }
    dfs(root, 0);  // node 0 is a sentinel parent with dep[0] == 0
    getpos(root, root);
    build(1, 1, pos);
    while (q--) {
        int ctrl, x, y, z;
        std::scanf("%d", &ctrl);
        if (ctrl == 1) {
            std::scanf("%d %d %d", &x, &y, &z);
            uplca(x, y, norm(z));
        } else if (ctrl == 2) {
            std::scanf("%d %d", &x, &y);
            std::printf("%lld\n", qulca(x, y));
        } else if (ctrl == 3) {
            std::scanf("%d %d", &x, &z);
            // A subtree occupies DFS positions [tid[x], tid[x] + sz[x] - 1].
            update(1, tid[x], tid[x] + sz[x] - 1, norm(z));
        } else if (ctrl == 4) {
            std::scanf("%d", &x);
            std::printf("%lld\n", query(1, tid[x], tid[x] + sz[x] - 1));
        }
    }
    return 0;
}