树链剖分/轻重链剖分(Heavy-Light Decomposition) - 如何在树上优雅地跳跃(施工中)

​ 树链剖分一般指重链剖分,是处理树上问题的一个有效手段。

​ 从某种意义上来讲,树链剖分是一种优雅的暴力

​ 常规意义上的树链剖分一般指重链剖分,实际上也可指长链剖分、虚实链剖分。在这篇博客中,若不特加说明,树链剖分均指前者,因为重链剖分的应用范围相对来说较为广泛。

概念

​ 树链剖分,顾名思义,是将一个树划分为多条链,并且以链为单位来完成查询/修改操作

​ 假设树上某节点u,那么u的各子树(这里指以u的各孩子节点为根的树)中包含节点数目最多的子树所对应的孩子称为u的重孩子,定义为son[u]。

​ 连接u和son[u]的边称为重边,由多条连续重边构成的链称为重链。相应地,非重边的边称为轻边。若一条轻边连接着一个叶子节点,那么这个叶子节点可视为位于长度为0的重链上。

​ 下面把以x为根的树中所包含的节点数目称为x的size。

性质

​ 假设一棵树的根节点为r,并且r有k个孩子。那么显然:其重孩子的size至少为\(\frac{1}{k}*size[r](k>1)\)​。

每次经过一条轻边,所到达子树的大小至少折半。原因如下:r拥有轻孩子的条件是r的孩子数目大于2,假设连接的是轻边,显然这条轻边连接的轻孩子的size必定不大于\(\frac{size[r]}{2}\),否则它将成为重孩子,即与假设相矛盾。

​ 根据上述性质,我们可以知道:由根节点出发到树上任一节点,经过的轻边数目必然小于\(log_2n\)。而重边只能插入到这条路径的两条轻边之间或是边界位置,且连续重边会形成一条重链。因此,这条路径上重链的数目必然不可能超过\(log_2n+2\)​。

​ 由上述结论可知,对于任意一条树上路径,其经过的轻重链条数的数量级必定至多为\(O(logn)\)级别,也就意味着,利用树剖维护树上路径的区间信息时,每次处理整条路径最多只需进行\(O(logn)\)​次区间操作

​ 由于我们每次操作均以重链为单位进行区间操作,假设每次区间操作的复杂度为\(O(k)\)​。由此可以推算得到:包括预处理在内的树链剖分的总时间复杂度为\(O(n+qklogn)\)​​​​​​,其中q为询问次数。

代码实现

​ 选择树的某一节点作为其根节点(一般来说是可以任选的,具体题目具体分析),而后进行两次DFS。

​ 第一次DFS,处理出子树的大小size以及深度dep等信息,并得到所有节点的重孩子son。

​ 第二次DFS,求该树的DFS序,并求得每个节点所在重链的top(重链上深度最小的节点)。由于搜索时优先搜索重孩子。这保证了任意一条重链在求得的DFS序上必然是连续的,因而对于每条重链操作时,都相当于对dfs序上的一个连续区间进行操作,而简单的区间操作正是我们所擅长的。

​ 考虑从x到y的一条树上路径。那么显然,这条路径可以被分解为若干条重链以及轻边所组成的区间。我们每次处理x所在当前链或者y所在的当前链、对其进行区间操作,而后使其跳转到下一条待处理链上。容易想到的问题是,如果点每次以链为单位进行跳转,由于 \(lca(x,y)\) 可能并不在重链的端点上,因此点有可能会跳到 \(lca(x,y)\)​ 的上面去。

​ 所以我们了解到一点:横跨 \(lca(x,y)\)​ 的重链可能导致点“跳过头”。

​ 我们不妨考虑这样一种处理方式:设dep[top]较大的点为\(z = dep[top[x]]>dep[top[y]]\) ? \(x\) : \(y\),优先处理z所在的重链、对该重链对应的dfs序区间进行操作,而后跳转到par[top[z]],不断重复上述过程(每一轮的 \(z\) 不一定相同),直至top[x] == top[y],结束。

​ 由重孩子的定义,可知x到y的路径上至多有一条横跨 \(lca(x,y)\) 的重链。因此,另一条链的top的深度必定大于 \(lca(x,y)\) 的深度,即在 \(lca(x,y)\) 的下方。所以每次处理 \(z\) 所在的重链可以保证不会使任意一点“跳过头”。

​ 同时,每次操作 \(z\) 都将向上爬升,而 \(z\) 显然只能进行有限次跳转。所以,在x和y中必然会有一个点率先到达 \(lca(x,y)\)。当某个点率先到达 \(lca(x,y)\) 后,另一个点继续按照上面的步骤迭代跳转,直到top[x] == top[y],即结束迭代过程时,这个点与 \(lca(x,y)\) 位于同一条重链上。

​ 可知此时x和y中深度较小的节点即为 \(lca(x,y)\)​, \(lca(x,y)\)​ 虽不一定是重链的一端,但无伤大雅:x和y在同一条重链上就意味着此时x到y的路径在dfs序中位于连续区间上。设 \(rk[p]\)​ 为节点 \(p\)​ 在dfs序中的位置,且此时 \(dep[x]<dep[y]\)​ ,那么最后一步只需处理\([rk[x],rk[y]]\)​​​​ 这个区间即可。

从上面的描述可以知道,树剖可以用来求LCA(迭代跳转过程结束后,深度较小的节点即为 \(lca(x,y)\) ),但它的作用远不止于此。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
void dfs1(int u,int fa){  //fa代表u的父亲
par[u] = fa;
dep[u] = dep[fa]+1;
size[u] = 1;
for(int i = head[u] ; i ; i = E[i].nx){
int& to = E[i].to;
if(to == fa)continue;
else dfs1(to,u);
size[u] += size[to];
if(size[to] > size[son[u]])son[u] = to;
}
}

int top[N]; //top[u]表示u所在重链的起始位置
int rk[N]; //rk[u]表示u在dfs序中的位置
int pid[N]; //pid[u]表示dfs序中
void dfs2(int u,int tp){
static int cnt = 0;
top[u] = tp;
rk[u] = ++cnt;
pid[cnt] = u;
if(son[u])dfs2(son[u],tp);
else return;
for(int i = head[u] ; i ; i = E[i].nx){
int& to = E[i].to;
if(to == par[u] || to == son[u])continue;
dfs2(to,to);
}
}

void func(int st,int ed){ //对于单条重链的区间操作
...
}

void dcp(){ //dcp意为decomposition,表示把x到y的树上路径划分为由若干条链组成的区间。
while(top[x] != top[y]){
int& z = dep[top[x]]>dep[top[y]] ? x : y;
func(rk[top[z]],rk[z]);
z = par[top[z]];
}
//在结束此轮循环后,x和y已经跳到同一条重链上了。
dep[x] < dep[y] ? func(x,y) : func(y,x);
}

例题

​ 限于篇幅原因,下列代码段仅给出AC代码中与树剖相关的内容,便于集中注意力理解树链剖分使用的过程。其余的部分重要程度较低,因此在此略过。

​ 若需参照,结合上方的树剖模板即可。希望能对读者有所帮助。

1.最近公共祖先(LCA)

1
2
3
4
5
6
7
int lca(int x,int y){
while(top[x] != top[y]){
if(dep[top[x]] > dep[top[y]])x = par[top[x]];
else y = par[top[y]];
}
return dep[x]<dep[y] ? x : y;
}

2.【模板】树链剖分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
inline void pd(int rt,int l,int r){
int lson = ls(rt),rson = rs(rt);
ll x = lazy(rt);
int mid = l+r>>1;
val(lson) += (mid-l+1)*x;
val(rson) += (r-mid)*x;
lazy(lson) += x;
lazy(rson) += x;
lazy(rt) = 0;
}
void p_upd(int pos,int x,int rt,int l,int r){
if(l == r)val(rt) += x;
else{
int mid = l+r>>1;
if(pos <= mid)p_upd(pos,x,ls(rt),l,mid);
else p_upd(pos,x,rs(rt),mid+1,r);
upar(rt);
}
}
ll p_ask(int pos,int rt,int l,int r){
if(l == r)return val(rt);
else{
int mid = l+r>>1;
if(pos <= mid)return p_ask(pos,ls(rt),l,mid);
else return p_ask(pos,rs(rt),mid+1,r);
}
}
ll r_ask(int cl,int cr,int rt = 1,int l = 1,int r = n){
if(cl <= l && r <= cr)return val(rt);
else{
int mid = l+r>>1;
ll ans = 0;
if(lazy(rt))pd(rt,l,r);
if(cl <= mid)ans += r_ask(cl,cr,ls(rt),l,mid);
if(cr > mid)ans += r_ask(cl,cr,rs(rt),mid+1,r);
return ans;
}
}
void r_upd(int ul,int ur,int x,int rt = 1,int l = 1,int r = n){
if(ul <= l && r <= ur)lazy(rt) += x,val(rt) += (r-l+1)*x;
else{
int mid = l+r>>1;
if(lazy(rt))pd(rt,l,r);
if(ul <= mid)r_upd(ul,ur,x,ls(rt),l,mid);
if(ur > mid)r_upd(ul,ur,x,rs(rt),mid+1,r);
upar(rt);
}
}
void son_upd(int u,int x){
r_upd(rk[u],rk[u]+size[u]-1,x,1,1,n);
}
void path_upd(int x,int y,int val){
while(top[x] != top[y]){
if(dep[top[x]] > dep[top[y]])r_upd(rk[top[x]],rk[x],val),x = par[top[x]];
else r_upd(rk[top[y]],rk[y],val),y = par[top[y]];
}
if(dep[x] > dep[y])swap(x,y);
r_upd(rk[x],rk[y],val);
}
ll son_ask(int u){
return r_ask(rk[u],rk[u]+size[u]-1,1,1,n);
}
ll path_ask(int x,int y){
ll ans = 0;
while(top[x] != top[y]){
if(dep[top[x]] > dep[top[y]])ans += r_ask(rk[top[x]],rk[x]), x = par[top[x]];
else ans += r_ask(rk[top[y]],rk[y]), y = par[top[y]];
}
if(dep[x] > dep[y])swap(x,y);
ans += r_ask(rk[x],rk[y]);
return ans;
}

3.【国家集训队】旅游

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
inline void upar(int rt){
val(rt) = val(ls(rt)) + val(rs(rt));
mi(rt) = min(mi(ls(rt)),mi(rs(rt)));
ma(rt) = max(ma(ls(rt)),ma(rs(rt)));
}
inline void toRev(int rt){
val(rt) = -val(rt);
swap(mi(rt),ma(rt));
mi(rt) = -mi(rt);
ma(rt) = -ma(rt);
}
void pushD(int rt){
toRev(ls(rt));
toRev(rs(rt));
lazy[ls(rt)] ^= 1;
lazy[rs(rt)] ^= 1;
lazy[rt] = 0;
return;
}
void p_upd(int pos,int x,int rt = 1,int l = 1,int r = n){
if(l == r)val(rt) = mi(rt) = ma(rt) = x;
else{
int mid = l+r>>1;
if(lazy[rt])pushD(rt);
if(pos <= mid)p_upd(pos,x,ls(rt),l,mid);
else p_upd(pos,x,rs(rt),mid+1,r);
upar(rt);
}
}

void r_upd(int ul,int ur,int rt = 1,int l = 1,int r = n){
if(ul <= l && r <= ur){
toRev(rt);
lazy[rt] ^= 1;
}
else{
int mid = l+r>>1;
if(lazy[rt])pushD(rt);
if(ul <= mid)r_upd(ul,ur,ls(rt),l,mid);
if(ur > mid)r_upd(ul,ur,rs(rt),mid+1,r);
upar(rt); //forget
}
}
int r_ask_sum(int cl,int cr,int rt = 1,int l = 1,int r = n){
if(cl <= l && r <= cr)return val(rt);
else{
int ans = 0;
int mid = l+r>>1;
if(lazy[rt])pushD(rt);
if(cl <= mid)ans += r_ask_sum(cl,cr,ls(rt),l,mid);
if(cr > mid)ans += r_ask_sum(cl,cr,rs(rt),mid+1,r);
return ans;
}
}
int r_ask_mi(int cl,int cr,int rt = 1,int l = 1,int r = n){
if(cl <= l && r <= cr)return mi(rt);
else{
int ans = 1005;
int mid = l+r>>1;
if(lazy[rt])pushD(rt);
if(cl <= mid)ans = min(ans, r_ask_mi(cl,cr,ls(rt),l,mid));
if(cr > mid)ans = min(ans, r_ask_mi(cl,cr,rs(rt),mid+1,r));
return ans;
}
}
int r_ask_ma(int cl,int cr,int rt = 1,int l = 1,int r = n){
if(cl <= l && r <= cr)return ma(rt);
else{
int ans = -1005;
int mid = l+r>>1;
if(lazy[rt])pushD(rt);
if(cl <= mid)ans = max(ans, r_ask_ma(cl,cr,ls(rt),l,mid));
if(cr > mid)ans = max(ans, r_ask_ma(cl,cr,rs(rt),mid+1,r));
return ans;
}
}
int path_sum(int x,int y){
int ans = 0;
while(top[x] != top[y]){
if(dep[top[x]] > dep[top[y]])ans += r_ask_sum(rk[top[x]],rk[x]), x = par[top[x]];
else ans += r_ask_sum(rk[top[y]],rk[y]), y = par[top[y]];
}
if(dep[x] > dep[y])swap(x,y);
if(x != y)ans += r_ask_sum(rk[x]+1,rk[y]);
return ans;
}
int path_ma(int x,int y){
int ans = -1005;
while(top[x] != top[y]){
if(dep[top[x]] > dep[top[y]])ans = max(ans, r_ask_ma(rk[top[x]],rk[x])), x = par[top[x]];
else ans = max(ans, r_ask_ma(rk[top[y]],rk[y])), y = par[top[y]];
}
if(dep[x] > dep[y])swap(x,y);
if(x != y)ans = max(ans, r_ask_ma(rk[x]+1,rk[y]));
return ans;
}
int path_mi(int x,int y){
int ans = 1005;
while(top[x] != top[y]){
if(dep[top[x]] > dep[top[y]])ans = min(ans, r_ask_mi(rk[top[x]],rk[x])), x = par[top[x]];
else ans = min(ans, r_ask_mi(rk[top[y]],rk[y])), y = par[top[y]];
}
if(dep[x] > dep[y])swap(x,y);
if(x != y)ans = min(ans, r_ask_mi(rk[x]+1,rk[y]));
return ans;
}
void path_upd(int x,int y){
while(top[x] != top[y]){
if(dep[top[x]] > dep[top[y]])r_upd(rk[top[x]],rk[x]), x = par[top[x]];
else r_upd(rk[top[y]],rk[y]), y = par[top[y]];
}
if(dep[x] > dep[y])swap(x,y);
if(x != y)r_upd(rk[x]+1,rk[y]);
}