求助一道简单题

P3384 【模板】重链剖分/树链剖分

lhrfc @ 2023-04-30 13:50:43

rt（如题：以下是我的代码，求帮忙查错）

#include <bits/stdc++.h>
using namespace std;
// Capacity: maximum number of vertices (+ slack for 1-based indexing).
const int N = 1e5 + 10;
// Not referenced by the code visible in this file.
const int inf = 0x3f3f3f3f;
// n: vertex count, m: operation count, p: modulus, r: root vertex (read in main).
int n, m, p, r;
// Next DFS position to hand out; incremented by dfs2.
int tot = 0;
// dfn[i]: initial weight of the vertex placed at DFS position i (filled by dfs2).
int dfn[N];
// id[i]: vertex placed at DFS position i (filled by dfs2).
int id[N];
// Per-vertex data for the heavy-light decomposition.
struct node {
    std::vector<int> e;  // adjacency list (undirected edges)
    int fa;              // parent (the root's parent is set to the root itself)
    int dep;             // depth, root has depth 1
    int maxson;          // heavy child, 0 when the vertex has no children
    int size;            // number of vertices in the subtree
    int id;              // DFS position assigned by dfs2
    int top;             // head of the heavy chain containing this vertex
    int w;               // initial weight read from input
} tr[N];
// First DFS pass over the subtree rooted at x.
// Records each vertex's parent, depth and subtree size, and chooses as the
// heavy child (maxson) the first child with the strictly largest subtree.
//   x   : current vertex
//   fa  : parent of x (its neighbour that must not be revisited)
//   dep : depth assigned to x
void dfs1(int x,int fa,int dep){
    auto &u=tr[x];
    u.fa=fa;
    u.dep=dep;
    u.size=1;
    int best=-1;  // largest child subtree size seen so far
    for(const int v:u.e){
        if(v!=fa){
            dfs1(v,x,dep+1);
            u.size+=tr[v].size;
            if(tr[v].size>best){
                best=tr[v].size;
                u.maxson=v;
            }
        }
    }
}
// Second DFS pass: hands out consecutive DFS positions, visiting the heavy
// child first so that every heavy chain occupies a contiguous range.
//   x   : current vertex
//   top : head of the heavy chain that x belongs to
// Fills tr[x].top, tr[x].id, and the position-indexed arrays dfn (weight)
// and id (vertex).
void dfs2(int x,int top){
    auto &u=tr[x];
    u.top=top;
    ++tot;
    dfn[tot]=u.w;
    id[tot]=x;
    u.id=tot;
    if(u.maxson==0) return;  // leaf: no heavy child, hence no children at all
    dfs2(u.maxson,top);       // heavy child continues the current chain
    for(const int v:u.e){
        // every light child starts a new chain headed by itself
        if(v!=u.fa&&v!=u.maxson) dfs2(v,v);
    }
}
// Lazy segment tree over the heavy-light DFS order: range add, range sum.
// Every stored value (sum and pending tag) is kept reduced modulo the global
// p in 64-bit arithmetic. p may be close to 2^31, so plain int arithmetic
// (tag*len, left+right, unreduced sums) would overflow.
class XDS{
public:
    // Node x covers DFS positions [l, r]. sum is the node's total mod p;
    // tag is a pending add (mod p) already included in sum but not yet
    // propagated to the two children.
    struct node{
        long long sum;
        int l,r;
        long long tag;
    }tr[N*4];
    // Adds v (must already lie in [0, p)) to every position covered by x.
    inline void apply(int x,long long v){
        // v < 2^31 and length <= N, so the product fits comfortably in 64 bits.
        tr[x].sum=(tr[x].sum+v*(tr[x].r-tr[x].l+1))%p;
        tr[x].tag=(tr[x].tag+v)%p;
    }
    // Recomputes x's sum from its two children.
    inline void pushup(int x){
        tr[x].sum=(tr[x*2].sum+tr[x*2+1].sum)%p;
    }
    // Pushes x's pending add down to its children. Only called on
    // non-leaf nodes.
    inline void pushdown(int x){
        if(tr[x].tag){
            apply(x*2,tr[x].tag);
            apply(x*2+1,tr[x].tag);
            tr[x].tag=0;
        }
    }
    // Builds node x over positions [l, r] from the global dfn[] weights.
    void build(int x,int l,int r){
        tr[x].l=l,tr[x].r=r,tr[x].tag=0;
        if(l==r){
            tr[x].sum=((long long)dfn[l]%p+p)%p;  // normalise into [0, p)
            return;
        }
        int mid=(l+r)/2;
        build(x*2,l,mid),build(x*2+1,mid+1,r);
        pushup(x);
    }
    // Returns the sum of positions [l, r] mod p (0 for an empty range).
    int query(int x,int l,int r){
        // Disjoint or empty range: contributes nothing. This also prevents
        // descending into a leaf's non-existent children when l > r.
        if(l>r||tr[x].r<l||tr[x].l>r) return 0;
        if(tr[x].l>=l&&tr[x].r<=r) return (int)tr[x].sum;
        pushdown(x);
        return (int)((query(x*2,l,r)+(long long)query(x*2+1,l,r))%p);
    }
    // Adds k to every position in [l, r] (no-op for an empty range).
    // k may be any int; it is normalised into [0, p).
    void change(int x,int l,int r,int k){
        if(l>r||tr[x].r<l||tr[x].l>r) return;
        if(tr[x].l>=l&&tr[x].r<=r){
            apply(x,((long long)k%p+p)%p);
            return;
        }
        pushdown(x);
        change(x*2,l,r,k);
        change(x*2+1,l,r,k);
        pushup(x);
    }
};
XDS xds;
// Operation 1: adds z (reduced mod p) to every vertex on the path x..y.
// Repeatedly lifts the endpoint whose chain head is deeper: it updates that
// endpoint's chain segment, then jumps to the head's parent. Once both lie on
// the same chain, it updates the remaining contiguous segment.
inline void f1(int x,int y,int z){
    z%=p;
    while(tr[x].top!=tr[y].top){
        // Lift the endpoint whose chain head is DEEPER; lifting the shallower
        // one could jump past the LCA.
        if(tr[tr[x].top].dep<tr[tr[y].top].dep) std::swap(x,y);
        int tx=tr[x].top;
        // The chain head has the smaller DFS position: segment is [id(top), id(x)].
        xds.change(1,tr[tx].id,tr[x].id,z);
        x=tr[tx].fa;
    }
    if(tr[x].dep>tr[y].dep) std::swap(x,y);
    xds.change(1,tr[x].id,tr[y].id,z);
}
// Operation 2: returns the sum of the vertex values on the path x..y, mod p.
// Uses the same chain-jumping scheme as the path update: always lift the
// endpoint whose chain head is deeper, and accumulate in 64 bits.
inline int f2(int x,int y){
    long long ans=0;  // 64-bit: ans + query can exceed INT_MAX when p ~ 2^31
    while(tr[x].top!=tr[y].top){
        if(tr[tr[x].top].dep<tr[tr[y].top].dep) std::swap(x,y);
        int tx=tr[x].top;
        // The chain head has the smaller DFS position: segment is [id(top), id(x)].
        ans=(ans+xds.query(1,tr[tx].id,tr[x].id))%p;
        x=tr[tx].fa;
    }
    if(tr[x].dep>tr[y].dep) std::swap(x,y);
    ans=(ans+xds.query(1,tr[x].id,tr[y].id))%p;
    return (int)ans;
}
inline void f3(int x,int z){
    int l=x,r=l+tr[x].size-1;
    xds.change(1,l,r,z%p);
}
inline int f4(int x){
    int l=x,r=l+tr[x].size-1;
    return xds.query(1,l,r)%p;
}
// Reads n, m, root r, modulus p, the vertex weights and the n-1 edges.
// Builds the heavy-light decomposition and the segment tree, then processes
// m operations:
//   1 x y z : f1 (add z along path x..y)
//   2 x y   : print path sum x..y mod p
//   3 x z   : f3 (add z to subtree of x)
//   4 x     : print subtree sum of x mod p
int main(){
    // Up to ~1e5 operations: untie C/C++ streams and avoid the per-answer
    // flush of std::endl, which otherwise dominates the running time.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin>>n>>m>>r>>p;
    for(int i=1;i<=n;i++) cin>>tr[i].w;
    for(int i=1;i<=n-1;i++){
        int a,b;
        cin>>a>>b;
        tr[a].e.push_back(b),tr[b].e.push_back(a);
    }
    dfs1(r,r,1);
    dfs2(r,r);
    xds.build(1,1,n);
    while(m--){
        int op,x,y,z;
        cin>>op;
        switch(op){
        case 1:
            cin>>x>>y>>z;
            f1(x,y,z);
            break;
        case 2:
            cin>>x>>y;
            cout<<f2(x,y)%p<<'\n';
            break;
        case 3:
            cin>>x>>z;
            f3(x,z);
            break;
        case 4:
            cin>>x;
            cout<<f4(x)%p<<'\n';
            break;
        }
    }
    return 0;
}

by heavenjcy @ 2023-04-30 14:45:09

qp


|