浅谈李超线段树及其应用

tommymio

2020-02-12 12:48:13

Algo. & Theory

在我的博客食用效果更佳

算法思想及实现

一种高级数据结构,最经典的应用就是维护一个二维平面直角坐标系,支持动态插入线段,询问与直线 x=x_0 相交的已插入线段中交点 y 的最大/最小值,即当 x=x_0,求\max\{k_ix+b_i\}\min\{k_ix+b_i\}。两种情况本质上是一样的,那我们现在就来讨论一下它的实现。

解决这个问题,最暴力的做法就是当 x=x_0 时,将 n 条线段全部遍历一遍,时间复杂度为 O(n),这似乎就是这个问题的瓶颈了,那么如何优化时间复杂度呢?这种做法给了我们启发,这个问题的瓶颈就在于需要遍历的集合过大,那么,如果找到一种方法能够有效减小集合大小,排除不可能成为最优解的,是否就能够降低时间复杂度呢?答案是肯定的。李超线段树使用的正是这种思想。

李超线段树是一种特殊的线段树,它的特殊之点在于每个区间只记录在当前区间可能成为最优解的线段,即该线段在当前区间的某个取值上有最优解。可以证明当查询的时候,只需要遍历 \log n 个节点,即可找出最优解,故单次查询时间复杂度为 O(\log n)。修改时通过替换和下放线段,可以达到 O(\log n) 的时间复杂度,注意此处的时间复杂度为全局修改的时间复杂度。特别地,对于区间修改的时间复杂度为 O(\log^2 n),读者可以自行证明。

维护区间内可能成为最优解的线段,且保证时间复杂度就成了这个算法的核心问题。对于这一问题,我们可以使用斜率和线段在当前区间中点上的取值,进行分类讨论。这里就以求最大值为例进行分析。

  1. 若当前区间内没有任何线段,则直接将新线段放入当前区间。

  2. 若当前区间 [L,R] 内有旧线段,且新线段的斜率大于旧线段的斜率,设当前区间中点为 mid:

    • 若旧线段在 mid 上的取值大于新线段在 mid 上的取值,则新线段在 [L,mid] 上的取值一定不比旧线段更优,但新线段在 [mid+1,R] 上的取值可能更优,将新线段下放至[mid+1,R]区间并递归更新答案。

    • 若旧线段在mid上的取值小于新线段在mid上的取值,则旧线段在[mid+1,R]上的取值一定不比新线段更优,但旧线段在[L,mid]上的取值可能更优,用新线段替换该当前节点的旧线段,将旧线段下放至 [L,mid] 区间并递归更新答案。

  3. 若当前区间 [L,R] 内有旧线段,且新线段的斜率小于旧线段的斜率,设当前区间中点为 mid :

    • 若旧线段在 mid 上的取值小于新线段在 mid 上的取值,则旧线段在 [L,mid] 上的取值一定不比新线段更优,但旧线段在 [mid+1,R] 上的取值可能更优,用新线段替换当前节点的旧线段,将旧线段下放至 [mid+1,R] 区间并递归更新答案。
    • 若旧线段在mid上的取值大于新线段在mid上的取值,则新线段在[mid+1,R] 上的取值一定不比旧线段更优,但新线段在 [L,mid] 上的取值可能更优,将新线段下放至 [L,mid] 区间并递归更新答案。
  4. 若当前区间 [L,R] 内有旧线段,且新旧线段斜率相同,比较截距b,截距大的线段在 [L,R] 上的取值一定优于截距小的线段,直接用截距大的替换截距小的。

如何理解通过两条线段在mid上的取值和线段的斜率就能确定哪条线段更优?我们可以通过旋转来理解。

y_1 为旧线段,y_3 为新线段,mid 为当前区间 [L,R] 的中点,这里以 y_1 的斜率大于 y_3y_1mid 上的取值大于 y_3 为例。设线段 y_2 与线段 y_1 斜率相同,且在 mid 上的取值与 y_3mid 上的取值相同,可以发现 y_2[L,mid],[mid+1,R] 上的取值一定不比 y_1 更优。还可以发现,当我们将y_2向逆时针旋转\beta^{\circ}(0^{\circ}<\beta^{\circ}<\alpha^{\circ}) 时,y_2 的斜率逐渐变小,在这过程中 y_2y_1[L,mid] 有交点且 y_2[L,mid] 的取值整体变大,说明 y_2[L,mid] 区间内可能存在某个取值使 y_2 成为最优解;而在这过程中,y_2[mid+1,R] 上的取值整体变小,一定不会成为最优解。易证将 y_2 向逆时针旋转 \beta^{\circ}(0^{\circ}<\beta^{\circ}<\alpha^{\circ}) 得到的 y_2' 中,一定存在某个 y_2'y_3 完全相同,故 y_2 变化趋势即为 y_3 变化趋势。

另外,经过观察,可以发现分类讨论的框架与线段树很相似,可以直接放到线段树上维护,这就是李超线段树。

例题

Luogu P4254/P4097 Blue Mary开公司/Segment

这是江苏2008年的省选题/河北2013年的省选题(省选居然考板子题

除了一些细节需注意以外,这两题很明显是个李超线段树的裸题,就直接上 P4254 的代码了。

#include<cstdio>
#include<iostream>
#include<algorithm>
const int N=500005;
int tot=0;
char s[15];
int tag[4000005];
double k[100005],b[100005];
inline double calc(int i,int x) {return k[i]*(x-1)+b[i];}
void change(int p,int l,int r,int x) {
    if(l==r) {
        if(calc(tag[p],l)<calc(x,l)) tag[p]=x;
        return;
    }
    if(!tag[p]) {tag[p]=x;return;}
    else {
        int mid=l+r>>1;
        double y1=calc(tag[p],mid),y2=calc(x,mid);
        if(k[tag[p]]<k[x]) {
            if(y1<=y2) {change(p*2,l,mid,tag[p]);tag[p]=x;} 
            else {change(p*2+1,mid+1,r,x);}
        }
        else if(k[tag[p]]>k[x]) {
            if(y1<=y2) {change(p*2+1,mid+1,r,tag[p]);tag[p]=x;}
            else {change(p*2,l,mid,x);}
        }
        else if(b[tag[p]]<b[x]) {tag[p]=x;}
    } 
}
double query(int p,int l,int r,int x) {
    if(l==r) return calc(tag[p],l);
    double res=calc(tag[p],x);
    int mid=l+r>>1;
    if(x<=mid) res=std::max(res,query(p*2,l,mid,x));
    else res=std::max(res,query(p*2+1,mid+1,r,x));
    return res;
}
int main() {
    int n;
    std::cin>>n;
    for(register int i=1;i<=n;++i) {
        std::cin>>s;
        if(s[0]=='Q') {
            int t;
            std::cin>>t;
            if(tot==0) printf("0\n");
            else printf("%d\n",(int)query(1,1,N,t)/100);
        }
        else {
            double s,p;++tot;
            std::cin>>b[tot]>>k[tot];
            change(1,1,N,tot);
        }
    }
    return 0;
}

Luogu P4069 [SDOI2016]游戏

题目简述

强行上树差评

这题挺简单的,树剖剖一剖,将树上问题转化为序列问题,就可以用李超线段树来做了。

但需要注意,李超线段树只能维护k值确定的线段,所以我们要将式子适当变形,然后直接在李超线段树上维护区间最值就可以了。

#include<cstdio>
#include<climits>
#include<cstring>
#include<algorithm>
const long long inf=LLONG_MAX;
int cnt=0,num=0,tot=0;
int tag[800005];
long long minn[800005],b[200005],k[200005],dis[100005];
int h[100005],to[200005],ver[200005],w[200005];
int size[100005],d[100005],fa[100005];
int seg[100005],rev[100005],top[100005],son[100005];
inline int read() {
    register int x=0,f=1;register char s=getchar();
    while(s>'9'||s<'0') {if(s=='-') f=-1;s=getchar();}
    while(s>='0'&&s<='9') {x=x*10+s-'0';s=getchar();}
    return x*f;
}
inline void add(int x,int y,int z) {
    to[++cnt]=y;
    ver[cnt]=h[x];
    w[cnt]=z;
    h[x]=cnt;
}
inline long long val(int id,int x) {return k[id]*dis[rev[x]]+b[id];}
void change(int p,int l,int r,int L,int R,int id) {//[L,R]查询区间
    int mid=l+r>>1;
    if(l==r&&val(id,l)<val(tag[p],l)) {tag[p]=id;minn[p]=val(id,l);return;}
    else if(l==r&&val(id,l)>=val(tag[p],l)) return;
    else if(L<=l&&R>=r) {
        if(val(id,l)>=val(tag[p],l)&&val(id,r)>=val(tag[p],r)) return;
        else if(val(id,l)<val(tag[p],l)&&val(id,r)<val(tag[p],r)) tag[p]=id;
        else {
            if(k[id]<k[tag[p]]) {
                if(val(id,mid)<=val(tag[p],mid)) {change(p*2,l,mid,L,R,tag[p]);tag[p]=id;}
                else {change(p*2+1,mid+1,r,L,R,id);}
            }
            else {
                if(val(id,mid)<=val(tag[p],mid)) {change(p*2+1,mid+1,r,L,R,tag[p]);tag[p]=id;}
                else {change(p*2,l,mid,L,R,id);}
            }   
        }
        minn[p]=std::min(minn[p],std::min(val(tag[p],l),val(tag[p],r)));
        minn[p]=std::min(minn[p],std::min(minn[p*2],minn[p*2+1]));
        return;
    }
    if(L<=mid) change(p*2,l,mid,L,R,id);
    if(R>mid) change(p*2+1,mid+1,r,L,R,id);
    minn[p]=std::min(minn[p],std::min(minn[p*2],minn[p*2+1]));
}
long long ask(int p,int l,int r,int L,int R) {
    if(L<=l&&R>=r) return minn[p];
    long long res=std::min(val(tag[p],std::max(l,L)),val(tag[p],std::min(r,R)));
    int mid=l+r>>1;
    if(L<=mid) res=std::min(res,ask(p*2,l,mid,L,R));
    if(R>mid) res=std::min(res,ask(p*2+1,mid+1,r,L,R));
    return res;
}
void dfs1(int x) {
    size[x]=1;
    for(register int i=h[x];i;i=ver[i]) {
        int y=to[i];
        if(y==fa[x]) continue;
        fa[y]=x;d[y]=d[x]+1;
        dis[y]=dis[x]+w[i];
        dfs1(y);
        if(size[son[x]]<size[y]) son[x]=y;
        size[x]+=size[y];
    }
}
void dfs2(int x,int t) {
    seg[x]=++num;rev[num]=x;top[x]=t;
    if(son[x]) dfs2(son[x],t);
    for(register int i=h[x];i;i=ver[i]) {
        int y=to[i];
        if(y==fa[x]||y==son[x]) continue;
        dfs2(y,y);
    }
}
inline int LCA(int x,int y) {
    while(top[x]!=top[y]) {d[top[x]]>d[top[y]]? x=fa[top[x]]:y=fa[top[y]];}
    return d[x]<d[y]? x:y;
}
inline void modify(int x,int y,int id) {
    while(top[x]!=top[y]) {
        if(d[top[x]]<d[top[y]]) std::swap(x,y);
        change(1,1,num,seg[top[x]],seg[x],id);
        x=fa[top[x]];
    }
    if(d[x]<d[y]) std::swap(x,y);
    change(1,1,num,seg[y],seg[x],id);
}
inline long long query(int x,int y) {
    long long res=inf;
    while(top[x]!=top[y]) {
        if(d[top[x]]<d[top[y]]) std::swap(x,y);
        res=std::min(res,ask(1,1,num,seg[top[x]],seg[x]));
        x=fa[top[x]];
    }
    if(d[x]>d[y]) std::swap(x,y);
    return std::min(res,ask(1,1,num,seg[x],seg[y]));
}
int main() {
    //freopen("game3.in","r",stdin);
    //freopen("game3.ans","w",stdout);
    int n=read(),m=read();
    k[0]=0;b[0]=inf;
    for(register int i=1;i<n;++i) {
        int x=read(),y=read(),z=read();
        add(x,y,z);add(y,x,z);
    }
    d[1]=1;
    dfs1(1);
    dfs2(1,1);
    for(register int i=1;i<=4*num;++i) minn[i]=inf;
    for(register int i=1;i<=m;++i) {
        int op=read(),s=read(),t=read();
        if(op==1) {
            long long a=read(),c=read();
            int u=LCA(s,t);
            k[++tot]=-a;b[tot]=a*dis[s]+c;
            modify(s,u,tot);
            k[++tot]=a;b[tot]=a*(dis[s]-2*dis[u])+c;
            modify(u,t,tot);
        }
        else {
            long long res=query(s,t);
            if(res==inf) printf("123456789123456789\n");
            else printf("%lld\n",res);
        }
    }
    return 0;
}

P4655 [CEOI2017]Building Bridges

题目简述

n 根柱子依次排列,每根柱子都有一个高度。第 i 根柱子的高度为 h_i

现在想要建造若干座桥,如果一座桥架在第 i 根柱子和第 j 根柱子之间,那么需要 (h_i-h_j)^2 的代价。

在造桥前,所有用不到的柱子都会被拆除,因为他们会干扰造桥进程。第 i 根柱子被拆除的代价为 w_i,注意 w_i 不一定非负,因为可能政府希望拆除某些柱子。

现在政府想要知道,通过桥梁把第 i 根柱子和第 j 根柱子连接的最小代价。注意桥梁不能在端点以外的任何地方相交。

分析&解答

很明显,如果没有最后一句话,这道题不可做。

但是,规定了不可能在端点外任何地方相交,一定是线性的,很明显想到DP。

f_i 为将第 1 根和第 i 根柱子相连的代价,则有状态转移方程:

f_i=\min\{f_j+\sum_{k=j+1}^{i-1} w_k + (h_i-h_j)^2\}

我们可以令 sum_i=\sum_{k=1}^{i} w_k,这样就可以将 \sum_{k=j+1}^{i-1} w_k 写作sum_{i-1}-sum_j,得:

f_i=\min\{f_j+sum_{i-1}-sum_j+ (h_i-h_j)^2\}

展开 (h_i-h_j)^2 得:

f_i=\min\{f_j+sum_{i-1}-sum_j+ h_i^2-2h_ih_j+h_j^2\}

整理可得:

f_i=\min\{(-2h_ih_j+f_j-sum_j+h_j^2)+(sum_{i-1}+ h_i^2)\}

看上去是不是很像斜率优化DP的样子?但是我不会

这既然是李超线段树的例题,当然要用李超线段树啦( 路人:我信了你的鬼逻辑,你能够动态维护这一堆乱七八糟的式子我把屏幕吃了

观察一下可以发现,若 i 确定,(sum_{i-1}+ h_i^2) 为常数项,可以忽略。

于是就变成了对于一个确定的 i ,求\min\{-2h_ih_j+f_j-sum_j+h_j^2\}

很明显的,这是一个类似于 \min\{kx+b\} 的式子,事实上它的值可以用李超 线段树求出来。

-2h_j 看作 k,则对于每一个 k,都有 f_j-sum_j+h_j^2 与之对应, 很明显这就是求 i-1 个一次函数在 h_i 上的取值最小,可以在转移过程中使 用李超线段树维护这个值,从而达到 O(n \log n) 的时间复杂度。

#include<cstdio>
#include<cstring>
#include<algorithm>
const long long inf=0x3f3f3f3f3f3f3f3f;
int tag[4000005];
long long h[100005],w[100005];
long long k[100005],b[100005];
long long sum[100005],f[100005];
inline int read() {
    register int x=0,f=1;register char s=getchar();
    while(s>'9'||s<'0') {if(s=='-') f=-1;s=getchar();}
    while(s>='0'&&s<='9') {x=x*10+s-'0';s=getchar();}
    return x*f;
}
inline long long val(int x,int id) {return k[id]*x+b[id];}
void add(int p,int l,int r,int id) {
    if(l==r) {
        if(val(l,id)<val(l,tag[p])) tag[p]=id;
        return;
    }
    int mid=l+r>>1;

    if(k[id]<k[tag[p]]) {
        if(val(mid,id)<=val(mid,tag[p])) {add(p*2,l,mid,tag[p]);tag[p]=id;}
        else {add(p*2+1,mid+1,r,id);}
    }
    else if(k[id]>k[tag[p]]) {
        if(val(mid,id)<=val(mid,tag[p])) {add(p*2+1,mid+1,r,tag[p]);tag[p]=id;}
        else {add(p*2,l,mid,id);}
    }
    else if(b[id]<b[tag[p]]) {
        tag[p]=id;
        return;
    }
}
long long ask(int p,int l,int r,int x) {
    if(l==r) return val(x,tag[p]);
    int mid=l+r>>1;
    long long res=val(x,tag[p]);
    if(x<=mid) return std::min(res,ask(p*2,l,mid,x));
    else return std::min(res,ask(p*2+1,mid+1,r,x));
}
int main() {
    int n=read();
    for(register int i=1;i<=n;++i) h[i]=read();
    for(register int i=1;i<=n;++i) w[i]=read();
    for(register int i=1;i<=n;++i) sum[i]=sum[i-1]+w[i];
    f[1]=0;b[0]=inf;
    k[1]=-2*h[1];b[1]=h[1]*h[1]-sum[1]+f[1];
    add(1,0,1e6,1);
    for(register int i=2;i<=n;++i) {
        f[i]=ask(1,0,1e6,h[i])+sum[i-1]+h[i]*h[i];
        k[i]=-2*h[i];b[i]=h[i]*h[i]-sum[i]+f[i];
        add(1,0,1e6,i);
    }
    printf("%lld\n",f[n]);
    return 0;
}

CF932F Escape Through Leaf

PS:应各位dalao的要求加上了李超线段树合并这一内容(

题面说的很清楚了,这题想让我们求树上每个点 x 到 树上任一叶子节点 y 的最小代价。

f_{x,y} 为从 x 号节点到达 i 号叶子节点的最小费用。

f_{x,i}=\min_{y \in subtree(x)} f_{y,i}+a_x*b_y 观察上式可以发现,当 $y$ 的值固定的时候,这个式子其实就是 $kx+b$,于是可以使用李超线段树来维护,同上题DP优化一样即可。 但是这是一个树上问题,我们考虑每次如何合并 $x$ 的子树对应的李超线段树,可以想到使用线段树合并来解决这个问题。 李超线段树的合并也很简单,边 $add$ 新合并的树的当前节点的 $tag$,然后递归合并即可,大体同线段树合并是一样的。~~限于篇幅,代码就不放了/kk~~ ```cpp #include<cstdio> #include<climits> typedef long long ll; const ll inf=LLONG_MAX; const int infL=-1e5,infR=1e5; int tot=0,num=0,cnt=0; ll b[300005],k[300005],res[300005]; int a[300005],c[300005]; int h[300005],to[600005],ver[600005]; int rt[300005],sonL[5000005],sonR[5000005],tag[5000005]; inline int read() { register int x=0,f=1;register char s=getchar(); while(s>'9'||s<'0') {if(s=='-') f=-1;s=getchar();} while(s>='0'&&s<='9') {x=x*10+s-'0';s=getchar();} return x*f; } inline void add(int x,int y) {to[++cnt]=y;ver[cnt]=h[x];h[x]=cnt;} inline ll min(ll x,ll y) {return x<y? x:y;} inline ll calc(ll x,int id) {return k[id]*x+b[id];} inline void change(int &p,int l,int r,int id) { if(!p) {tag[p=++tot]=id;return;} if(l==r) { if(calc(l,id)<calc(l,tag[p])) tag[p]=id; return; } int mid=l+r>>1; if(k[id]<k[tag[p]]) { if(calc(mid,id)<=calc(mid,tag[p])) {change(sonL[p],l,mid,tag[p]);tag[p]=id;} else {change(sonR[p],mid+1,r,id);} } else if(k[id]>k[tag[p]]) { if(calc(mid,id)<=calc(mid,tag[p])) {change(sonR[p],mid+1,r,tag[p]);tag[p]=id;} else {change(sonL[p],l,mid,id);} } else if(b[id]<b[tag[p]]) { tag[p]=id; return; } } inline ll ask(int p,int l,int r,int x) { if(!p) return inf; if(l==r) return calc(l,tag[p]); int mid=l+r>>1; ll minn=calc(x,tag[p]); if(x<=mid) return min(minn,ask(sonL[p],l,mid,x)); else return min(minn,ask(sonR[p],mid+1,r,x)); } inline int merge(int x,int y,int l,int r) { if(!x||!y) return x+y; if(l==r) return calc(l,tag[x])>calc(l,tag[y])? y:x; int mid=l+r>>1; sonL[x]=merge(sonL[x],sonL[y],l,mid); sonR[x]=merge(sonR[x],sonR[y],mid+1,r); change(x,l,r,tag[y]); return x; } inline void dfs(int x,int fa) { int du=0; for(register int i=h[x];i;i=ver[i]) { int y=to[i]; if(y==fa) continue; dfs(y,x);++du; rt[x]=merge(rt[x],rt[y],infL,infR); } k[++num]=c[x]; if(du) res[x]=b[num]=ask(rt[x],infL,infR,a[x]); else res[x]=b[num]=0; change(rt[x],infL,infR,num); } signed main() { int n=read(); for(register int i=1;i<=n;++i) a[i]=read(); for(register int i=1;i<=n;++i) c[i]=read(); for(register int i=1;i<n;++i) {int x=read(),y=read();add(x,y);add(y,x);} dfs(1,-1); for(register int i=1;i<=n;++i) printf("%lld ",res[i]); return 0; } ``` ## 总结 李超线段树是一种基于线段树维护一次函数在值域上维护最值的 $\log$ 级数据结构,多用于一些最优性问题,但主要用途是用来优化DP,绕开繁琐的斜率优化DP推导过程或维护凸壳的过程。 在DP优化中使用李超线段树能够显著减少码量,提升代码可读性,常数和线段树常数一样(这不废话么 由于李超线段树是一种线段树,因此线段树的性质李超线段树都具有,具体的一个体现就是 **CF932F** 的李超线段树合并。 缺点是对于一些斜率优化DP有所局限,不能完全适用。但是根据题目的性质合理使用,一定能给你的做题带来良好体验(大雾