重链剖分 37pts 求条

P3384 【模板】重链剖分/树链剖分

ARIS2_0 @ 2024-12-25 20:57:23

rt,有注释。

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define pii pair<int,int>
#define fi first
#define se second
#define mp make_pair
const int inf=1e16,maxn=5e5+10;
vector<int>v[maxn];//存图
int size[maxn],wc[maxn],dep[maxn],fa[maxn];
//size为子树大小,wc为重儿子编号,dep为深度,fa为父亲
bool isw[maxn];
//isw_i为第i个点是否是重节点
void dfs1(int x,int father){
    fa[x]=father;
    size[x]=1;
    int maxs=0,id=0;
    for(int i=0;i<v[x].size();i++){
        int y=v[x][i];
        if(y!=fa[x]){
            dep[y]=dep[x]+1;
            dfs1(y,x);
            size[x]+=size[y];
            if(size[y]>maxs)maxs=size[y],id=y;
        }
    }
    wc[x]=id;
    isw[id]=1;
}
int d[maxn],p[maxn],top[maxn],tot;
//d为dfs序,p为第i个数在dfs中的位置,top为链顶,tot为计算dfs序变量
void dfs2(int x){
    d[++tot]=x;
    top[x]=isw[x]?top[fa[x]]:x;
    for(int i=0;i<v[x].size();i++){
        int y=v[x][i];
        if(y!=fa[x])dfs2(y);
    }
}
int n,m,root,mod,val[maxn];
int w[4*maxn],tag[4*maxn];
//以下为线段树,其中[l,r]为现在的区间,[cl,cr]为询问区间,inr为完全包含,outr为完全不含
void pushup(int id){w[id]=(w[id*2]+w[id*2+1])%mod;}
void build(int id,int l,int r){
    if(l==r){w[id]=val[d[l]];return;}
    int mid=(l+r)/2;
    build(id*2,l,mid);
    build(id*2+1,mid+1,r);
    pushup(id);
}
void maketag(int id,int pos,int len){
    (tag[id]+=pos)%=mod;
    (w[id]+=(pos*len)%mod)%=mod;
}
void pushdown(int id,int l,int r){
    int mid=(l+r)/2;
    maketag(id*2,tag[id],mid-l+1);
    maketag(id*2+1,tag[id],r-mid);
    tag[id]=0;
}
bool inr(int l,int r,int cl,int cr){return cl<=l and r<=cr;}
bool outr(int l,int r,int cl,int cr){return cr<l or r<cl;}
void update(int id,int l,int r,int cl,int cr,int pos){
    if(inr(l,r,cl,cr)){maketag(id,pos,r-l+1);return;}
    if(outr(l,r,cl,cr))return;
    pushdown(id,l,r);
    int mid=(l+r)/2;
    update(id*2,l,mid,cl,cr,pos);
    update(id*2+1,mid+1,r,cl,cr,pos);
    pushup(id);
}
int check(int id,int l,int r,int cl,int cr){
    if(inr(l,r,cl,cr))return w[id];
    if(outr(l,r,cl,cr))return 0;
    pushdown(id,l,r);
    int mid=(l+r)/2;
    return (check(id*2,l,mid,cl,cr)+check(id*2+1,mid+1,r,cl,cr))%mod;
}
//线段树结束
void update(int x,int y,int pos){//将x到y上的节点加pos
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        update(1,1,n,p[top[x]],p[x],pos);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    update(1,1,n,p[x],p[y],pos);
}
int check(int x,int y){//查询x到y的权值和
    int ans=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        (ans+=check(1,1,n,p[top[x]],p[x]))%=mod;
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    (ans+=check(1,1,n,p[x],p[y]))%=mod;
    return ans;
}
signed main(){
    ios::sync_with_stdio(0);
    cin.tie(0);
    cin>>n>>m>>root>>mod;
    for(int i=1;i<=n;i++)cin>>val[i],val[i]%=mod;
    for(int i=1;i<n;i++){
        int p,q;cin>>p>>q;
        v[p].push_back(q);
        v[q].push_back(p);
    }
    dfs1(root,0);
    dfs2(root);
    for(int i=1;i<=n;i++)p[d[i]]=i;
    build(1,1,n);
    while(m--){
        int op,x,y,z;cin>>op>>x;
        if(op==4)cout<<(check(1,1,n,p[x],p[x]+size[x]-1)+mod)%mod<<"\n";
        else{
            cin>>y;
            if(op==2)cout<<(check(x,y)+mod)%mod<<"\n";
            else if(op==3)y%=mod,update(1,1,n,p[x],p[x]+size[x]-1,y);
            else{
                cin>>z;z%=mod;
                update(x,y,z);
            }
        }
    }
    return 0;
}
//qwq

by UMS2 @ 2024-12-25 21:12:31

看起来是 dfs 序出问题了,应该优先遍历重儿子。@ARIS2_0


by vegetable_chili @ 2024-12-25 21:13:32

重链剖分 dfs2 要先遍历重儿子再遍历轻儿子,这样跑出来 d 和 p 的值才是对的。


by vegetable_chili @ 2024-12-25 21:15:12

这样是确保一条重链上的点的 p 值连续。


by ARIS2_0 @ 2024-12-25 21:15:58

@UMS2@vegetable_chili 感谢。自己怎么唐玩了。我改一下看看。


by ARIS2_0 @ 2024-12-25 21:17:47

@UMS2@vegetable_chili AC 了,万分感谢。此贴结。


|