从零开始掌握线段树大法

Brilliant11001

2023-12-30 10:11:57

Theory

# 简介: 线段树($\texttt {Segment Tree}$) 是一种高级数据结构,是一种**基于分治思想的二叉树结构**,主要用来处理**区间问题**。它可以在 $O(\log n)$ 的时间复杂度内维护序列中**满足结合律**的变量,例如:$max$,$min$,$\sum$ 和 $xor$。总的说来还是一个功能非常强大的数据结构,也有许多拓展。 下面就来逐步揭开线段树的神秘面纱。 ------------ ## 线段树的基本知识及建树 线段树的本质是一棵二叉树,它有以下特性: 1. 线段树的每个节点都代表一个区间; 2. 线段树具有唯一的根结点,代表的区间是整个统计范围,如 $[1,N]$; 3. 线段树的每个叶子节点都代表一个长度为 $1$ 的元区间(元线段)$[x,x]$; 4. 对于**每个内部节点 $[l,r]$,它的左子结点是 $[l,mid]$,右子节点是 $[mid + 1,r]$,其中 $mid = l + r >> 1$(向下取整),这样也保证了线段树对区间包括地不重不漏。** ![](https://cdn.luogu.com.cn/upload/image_hosting/lvu0w46l.png) 其实还是非常的形象,我们可以发现,**除去最后一层,整棵线段树是一棵满二叉树**,树的深度是 $O(\log n)$,因此我们可以按照与二叉堆类似的 **“父子 $2$ 倍”节点编号的方法**。 1. 根节点编号为 $1$; 2. 编号为 $x$ 的节点的左子结点编号为 $x * 2$,右子节点编号为 $x * 2 + 1$。 这样一来,就可以用结构体来存储树中的信息。**这里要注意:在理想情况下,$N$ 个节点的满二叉树有 $N + N / 2 + N / 4 + \cdots + 2 + 1 = 2N - 1$ 个节点。因为在这种存储方式下,最后一行会产生空余,最后一行会有 $2N$ 个空间,所以保存线段树的数组长度要不小于 $4N$,才能保证不会越界。** ```cpp struct SegmentTree{ int l, r, data; }tr[N]; ``` 以下都以维护区间最大值为例。 ### 建树 其实根据上面的图,思路已经很明了了:**递归建树!** 代码: ```cpp void build(int p, int l, int r) { tr[p].l = l, tr[u].r = r; if(l == r) { tr[p].maxx = a[l]; return ; //叶子结点 } int mid = l + r >> 1; build(p << 1, l , mid); //建左子树 build(p << 1 | 1, mid + 1, r); //建右子树 tr[p].maxx = max(tr[p << 1].maxx, tr[p << 1 | 1].maxx);//整合子节点的信息 } ``` ### 单点修改 单点修改是一条类似 $\texttt{"C x v"}$ 的指令,表示把 $A[x]$ 的值修改为 $v$。 在线段树中,根节点(编号为 $1$ 的节点)是所有指令的入口。所以思路就出来了:从根节点开始,递归找到代表 $[x,x]$ 区间的叶子节点,并把其值更新。由于递归时会先到达底端,再向上回溯,所以我们可以在回溯时顺便在父节点整合子节点的信息。时间复杂度为 $O(\log n)$。 代码: ```cpp void change(int p, int x, int val) { if(l(p) == r(p)) { maxx(p) = val; return ; } int mid = l(p) + r(p) >> 1; if(x <= mid) change(ls(p), x, val); //递归左儿子 else change(rs(p), x, val); //递归右儿子 pushup(p); //整合信息 } change(1, x, v); //调用入口 ``` ### 区间查询 单点修改是一条类似 $\texttt{"Q l r"}$ 的指令,例如查询序列 $A$ 在区间 $[l,r]$ 上的最大值,即$\max_{l\le i\le r}A[i]$。同样的,我们只需要从根节点开始,递归执行以下过程即可: 1. 若 $[l,r]$ 完全覆盖了当前节点代表的区间,就可以直接返回该节点的信息。 2. 若左儿子与 $[l,r]$ 有交集,则递归到左儿子。 3. 若右儿子与 $[l,r]$ 有交集,则递归到右儿子。 代码: ```cpp int query(int p, int l, int r) { if(l <= l(p) && r >= r(p)) return maxx(p); int res = -(1 << 30); int mid = l(p) + r(p) >> 1; if(l <= mid) res = max(res, query(ls(p), l, r)); if(r > mid) res = max(res, query(rs(p), l, r)); return res; } ``` 该查询过程会把**询问区间在线段树上分成 $O(\log n)$ 个节点**,所以时间复杂度为 $O(\log n)$。 为什们呢?我们不妨分类讨论一下: $1. \space\space l \le p_l\le p_r \le r$,则此时完全覆盖了当前节点,直接返回。 $2. \space\space p_l\le l\le p_r\le r$,此时只有 $l$ 处于节点之内,则: $\space\space\space\space$ (1)$\space\space l > mid$,只会递归右子树 $\space\space\space\space$ (2)$\space\space l \le mid$,虽然递归两棵子树,但是右儿子会在递归后直接返回。 $3. \space\space l\le p_l\le r\le p_r$,即只有 $r$ 处于节点之内,与情况 $2$ 类似。 $4. \space\space p_l\le l\le r\le p_r$,即 $l$ 和 $r$ 都位于节点之内。 $\space\space\space\space$ (1)$\space\space l,r$ 都位于 $mid$ 的一侧,只会递归一棵子树。 $\space\space\space\space$ (2)$\space\space l,r$ 分别位于 $mid$ 的两侧,递归左右两棵子树。 也就是说,只有情况 $4(2)$ 会真正产生对左右两棵子树的递归。这种情况至多发生一次,之后在子结点上就会变成情况 $2$ 或 $3$。因此,上述查询过程的时间复杂度为 $O(2\log n) = O(\log n)$。从宏观上理解,相当于 $l,r$ 两个端点分别在线段树上划分一条递归访问路径,情况 $4(2)$ 在两条路径与从下往上的第一次交会处产生。 ### 区间修改 这种情况就要比单点修改棘手一点,毕竟要比人家多改很多点,但时间复杂度还要在一个数量级,确实不简单。 试想一下,某个非叶子节点被修改区间 $[l,r]$ 完全覆盖,若直接从上到下传导修改信息,那么以该节点为根的子树就要全部被修改,时间按复杂度 $O(n)$,这是我们不能接受的。 再试想,对于一次区间修改如果我们发现某个节点 $p$ 所代表的区间被查询区间 $[l,r]$ 完全覆盖,并将此子树 $p$ 全部更新。但是在之后的查询操作中却完全没有用到 $[l,r]$ 的子区间的信息,那么更新整棵子树就是徒劳。 那怎么办呢?这时候就需要引入一个新的东西:**“延迟标记”**,又叫做 $\texttt{lazy tag}$,来标识“该节点曾经被修改,但其子节点尚未被更新”。 如果在后续的指令中,需要从节点 $p$ 、向下递归,我们再检查 $p$ 是否有标记。若有标记,就先把 $p$ 的子节点更新,给两个子节点打上标记,再把 $p$ 的标记消除。 这样一来,除了在修改指令中直接划分的 $O(\log n)$ 个节点之外,对任意节点修改都延迟到“在后续操作中递归进入它的父节点时”在执行。每条查询或修改操作的时间复杂度都降低到了 $O(\log n)$。 #### [【模板】线段树 1](https://www.luogu.com.cn/problem/P3372) 代码: ```cpp #include <iostream> using namespace std; const int N = 100010; typedef long long ll; struct SegmentTree { int l, r; ll sum, add; #define l(x) tr[x].l #define r(x) tr[x].r #define sum(x) tr[x].sum #define add(x) tr[x].add }tr[N * 4]; int n, m; ll a[N]; void build(int p, int l, int r) { l(p) = l, r(p) = r; if(l == r) { sum(p) = a[l]; return ; } int mid = l + r >> 1; build(p * 2, l, mid); build(p * 2 + 1, mid + 1, r); sum(p) = sum(p * 2) + sum(p * 2 + 1); } void spread(int p) { if(add(p)) { sum(p * 2) += add(p) * (r(p * 2) - l(p * 2) + 1); sum(p * 2 + 1) += add(p) * (r(p * 2 + 1) - l(p * 2 + 1) + 1); //更新子节点 add(p * 2) += add(p); add(p * 2 + 1) += add(p); //下传标记 add(p) = 0; //消除父节点的标记 } } void change(int p, int l, int r, ll val) { if(l <= l(p) && r >= r(p)) { sum(p) += val * (r(p) - l(p) + 1), add(p) += val; return ; } //完全包含 spread(p); //下传懒标记 int mid = l(p) + r(p) >> 1; if(l <= mid) change(p * 2, l, r, val); if(r > mid) change(p * 2 + 1, l, r, val); sum(p) = sum(p * 2) + sum(p * 2 + 1); } ll query(int p, int l, int r) { if(l <= l(p) && r >= r(p)) return sum(p); spread(p); //查询也要下传懒标记 ll res = 0; int mid = l(p) + r(p) >> 1; if(l <= mid) res += query(p * 2, l, r); if(r > mid) res += query(p * 2 + 1, l, r); return res; } int main() { scanf("%d%d", &n, &m); for(int i = 1; i <= n; i++) scanf("%lld", &a[i]); build(1, 1, n); int op, x, y; ll k; while(m--) { scanf("%d%d%d", &op, &x, &y); if(op == 1) { scanf("%lld", &k); change(1, x, y, k); } else { printf("%lld\n", query(1, x, y)); } } return 0; } ``` #### [【模板】线段树 2](https://www.luogu.com.cn/problem/P3372) 这道题要维护乘和加两个懒标记。 **注意:要先乘再加,并且再乘之后 add 的懒标记也要相对改变。** 代码: ```cpp #include <iostream> using namespace std; const int N = 100010; typedef long long ll; int n, m, mod; ll a[N]; struct SegmentTree{ int l, r; ll sum, add, mul; #define l(x) tr[x].l #define r(x) tr[x].r #define sum(x) tr[x].sum #define add(x) tr[x].add #define mul(x) tr[x].mul }tr[N * 4]; int ls(int p) {return p * 2;} int rs(int p) {return p * 2 + 1;} void pushup(int p) {sum(p) = (sum(ls(p)) + sum(rs(p))) % mod;} void build(int p, int l, int r) { l(p) = l, r(p) = r, mul(p) = 1; if(l == r) { sum(p) = a[l] % mod; return ; } int mid = l + r >> 1; build(ls(p), l, mid); build(rs(p), mid + 1, r); pushup(p); } void spread(int p) { sum(ls(p)) = (sum(ls(p)) * mul(p) % mod + (r(ls(p)) - l(ls(p)) + 1) * add(p) % mod) % mod; sum(rs(p)) = (sum(rs(p)) * mul(p) % mod + (r(rs(p)) - l(rs(p)) + 1) * add(p) % mod) % mod; //更新sum,注意先乘再加 mul(ls(p)) = mul(ls(p)) * mul(p) % mod; mul(rs(p)) = mul(rs(p)) * mul(p) % mod; add(ls(p)) = (add(p) + add(ls(p)) * mul(p) % mod) % mod; add(rs(p)) = (add(p) + add(rs(p)) * mul(p) % mod) % mod; //add 懒标记也要变 mul(p) = 1, add(p) = 0; //消除懒标记 } void change1(int p, int l, int r, ll val) { if(l <= l(p) && r >= r(p)) { sum(p) = sum(p) * val % mod; mul(p) = mul(p) * val % mod; add(p) = add(p) * val % mod; //add 懒标记也要改变 return ; } spread(p); int mid = l(p) + r(p) >> 1; if(l <= mid) change1(ls(p), l, r, val); if(r > mid) change1(rs(p), l, r, val); pushup(p); } void change2(int p, int l, int r, ll val) { if(l <= l(p) && r >= r(p)) { sum(p) = (sum(p) + (r(p) - l(p) + 1) * val % mod) % mod; add(p) = (add(p) + val) % mod; return ; } spread(p); int mid = l(p) + r(p) >> 1; if(l <= mid) change2(ls(p), l, r, val); if(r > mid) change2(rs(p), l, r, val); pushup(p); } ll query(int p, int l, int r) { if(l <= l(p) && r >= r(p)) return sum(p); spread(p); int mid = l(p) + r(p) >> 1; ll res = 0; if(l <= mid) res = (res + query(ls(p), l, r)) % mod; if(r > mid) res = (res + query(rs(p), l, r)) % mod; return res; } int main() { scanf("%d%d%d", &n, &m, &mod); for(int i = 1; i <= n; i++) scanf("%lld", &a[i]); build(1, 1, n); int op, x, y; ll k; while(m--) { scanf("%d%d%d", &op, &x, &y); if(op == 1) { scanf("%lld", &k); change1(1, x, y, k); } else if(op == 2) { scanf("%lld", &k); change2(1, x, y, k); } else { printf("%lld\n", query(1, x, y)); } } return 0; } ``` #### [P1253 扶苏的问题](https://www.luogu.com.cn/problem/P1253) 要多维护一个覆盖的懒标记。 **注意:覆盖某节点时该节点的 add 懒标记要清零(人都没了还更新啥)。** 代码: ```cpp #include <iostream> #include <cstdio> #include <cstdlib> using namespace std; const int N = 1000010; typedef long long ll; const ll inf = 0x3f3f3f3f3f3f3f3f; struct SegmentTree{ int l, r; ll add, cover, maxx; #define l(x) tr[x].l #define r(x) tr[x].r #define maxx(x) tr[x].maxx #define add(x) tr[x].add #define cover(x) tr[x].cover }tr[N * 4]; int n, q; ll a[N]; int ls(int p) {return p << 1;} int rs(int p) {return p << 1 | 1;} void pushup(int p) {maxx(p) = max(maxx(ls(p)), maxx(rs(p)));} void spread(int p){ if(cover(p) != -inf && l(p) != r(p)){ maxx(ls(p)) = cover(p); maxx(rs(p)) = cover(p); cover(ls(p)) = cover(p); cover(rs(p)) = cover(p); add(ls(p)) = add(rs(p)) = 0; cover(p) = -inf; } if(add(p) && l(p) != r(p)) { maxx(ls(p)) += add(p); maxx(rs(p)) += add(p); add(ls(p)) += add(p); add(rs(p)) += add(p); add(p) = 0; } } void build(int p, int l, int r) { l(p) = l, r(p) = r, cover(p) = -inf; if(l == r) { maxx(p) = a[l]; return ; } int mid = l + r >> 1; build(ls(p), l, mid); build(rs(p), mid + 1, r); pushup(p); } void change(int p, int l, int r, ll val) { if(l <= l(p) && r >= r(p)) { maxx(p) = val; add(p) = 0; cover(p) = val; return ; } spread(p); int mid = l(p) + r(p) >> 1; if(l <= mid) change(ls(p), l, r, val); if(r > mid) change(rs(p), l, r, val); pushup(p); } void pluss(int p, int l, int r, ll val) { if(l <= l(p) && r >= r(p)) { maxx(p) += val; add(p) += val; return ; } spread(p); int mid = l(p) + r(p) >> 1; if(l <= mid) pluss(ls(p), l, r, val); if(r > mid) pluss(rs(p), l, r, val); pushup(p); } ll query(int p, int l, int r) { if(l <= l(p) && r >= r(p)) return maxx(p); spread(p); ll res = -inf; int mid = l(p) + r(p) >> 1; if(l <= mid) res = max(res, query(ls(p), l, r)); if(r > mid) res = max(res, query(rs(p), l, r)); return res; } int main() { scanf("%d%d", &n, &q); for(int i = 1; i <= n; i++) scanf("%lld", &a[i]); build(1, 1, n); int op, x, y; ll k; while(q--) { scanf("%d%d%d", &op, &x, &y); if(op == 1) { scanf("%lld", &k); change(1, x, y, k); } else if(op == 2) { scanf("%lld", &k); pluss(1, x, y, k); } else { printf("%lld\n", query(1, x, y)); } } return 0; } ``` #### [SP1716](https://www.luogu.com.cn/problem/SP1716) 单点修改 + 区间最大子段和。 因为父节点的和最大的子段可能会跨区间,所以不能直接维护最大子段和,这时候就需要分类讨论最大子段和的取值情况。 1. 父节点的最大子段和在左儿子上。 ![asdsajdfhiujhkja.png](https://i.loli.net/2020/03/03/MbQhGW7ruBJPYXk.png) 2. 父节点的最大子段和在右儿子上。 ![asdasajdfhiujhkja.png](https://i.loli.net/2020/03/03/EO9GtVwImgJ163u.png) 3. 跨节点。 ![aasdasajdfhiujhkja.png](https://i.loli.net/2020/03/03/fpndbVqamoOIXtK.png) 由以上三个图可知,父节点的最大子段和就是**左儿子的最大子段和**、**右儿子的最大子段和**和**左儿子的最大后缀和 + 左儿子的最大前缀和三个中的最大值**,所以我们可以再维护三个值:**区间和,区间最大前缀和区间最大后缀。** 首先区间和很好维护,那剩下两个怎么办呢? 还是分类讨论取值情况。(以最大前缀为例,最大后缀也是同理) 1. 不跨区间 ![aaasajdfhiujhkja.png](https://i.loli.net/2020/03/03/b2BmrAucJ8jLdER.png) 2. 跨区间 ![asdssssajdfhiujhkja.png](https://i.loli.net/2020/03/03/8TJ4RkfYzWdVSPX.png) 所以最大前缀和就是**左儿子的最大前缀和**和**左儿子区间和 + 右儿子的最大前缀和的最大值**。 剩下的就是线段树模板了: ```cpp #include <iostream> using namespace std; const int N = 500010; int n, m; int a[N]; struct SegmentTree { int l, r; int sum; int lmax, rmax, tmax; #define l(x) tr[x].l #define r(x) tr[x].r #define sum(x) tr[x].sum #define lmax(x) tr[x].lmax #define rmax(x) tr[x].rmax #define tmax(x) tr[x].tmax }tr[N << 2]; inline int ls(int p) {return p << 1;} inline int rs(int p) {return p << 1 | 1;} inline void pushup(int p) { sum(p) = sum(ls(p)) + sum(rs(p)); lmax(p) = max(lmax(ls(p)), sum(ls(p)) + lmax(rs(p))); rmax(p) = max(rmax(rs(p)), sum(rs(p)) + rmax(ls(p))); tmax(p) = max(max(tmax(ls(p)), tmax(rs(p))), rmax(ls(p)) + lmax(rs(p))); } void build(int p, int l, int r) { l(p) = l, r(p) = r; if(l == r) { tmax(p) = lmax(p) = rmax(p) = sum(p) = a[l]; return ; } int mid = l + r >> 1; build(ls(p), l, mid); build(rs(p), mid + 1, r); pushup(p); } void modify(int p, int x, int val) { if(l(p) == r(p)) { tmax(p) = lmax(p) = rmax(p) = sum(p) = val; return ; } int mid = l(p) + r(p) >> 1; if(x <= mid) modify(ls(p), x, val); else modify(rs(p), x, val); pushup(p); } SegmentTree query(int p, int l, int r) { if(l <= l(p) && r >= r(p)) return tr[p]; int mid = l(p) + r(p) >> 1; if(r <= mid) return query(ls(p), l, r); if(l > mid) return query(rs(p), l, r); SegmentTree res, res1, res2; res1 = query(ls(p), l, r); res2 = query(rs(p), l, r); res.sum = res1.sum + res2.sum; res.lmax = max(res1.lmax, res1.sum + res2.lmax); res.rmax = max(res2.rmax, res2.sum + res1.rmax); res.tmax = max(max(res1.tmax, res2.tmax), res1.rmax + res2.lmax); return res; } int main() { scanf("%d%d", &n, &m); for(int i = 1; i <= n; i++) scanf("%d", &a[i]); build(1, 1, n); char op[2]; int x, y; while(m--) { scanf("%s%d%d", op, &x, &y); if(op[0] == '2') { modify(1, x, y); } else { if(x > y) swap(x, y); printf("%d\n", query(1, x, y).tmax); } } return 0; } ```