树剖没过样例求助

P3384 【模板】重链剖分/树链剖分

luqyou @ 2023-03-15 13:34:40

#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
vector<int> G[maxn];
int v[maxn],val[maxn],top[maxn],f[maxn],hson[maxn],size[maxn],dfn[maxn],rank[maxn],depth[maxn],cnt;
int n,m,root,mod;
struct node{
    int l,r,v,tag;
}a[maxn*4];
void dfs1(int u,int fa,int dep){
    f[u]=fa;
    depth[u]=dep;
    size[u]=1;
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i];
        if(v!=fa){
            dfs1(v,u,dep+1);
            size[u]+=size[v];
            if(size[v]>size[hson[u]]){
                hson[u]=v;
            } 
        }
    }
}
void dfs2(int u,int fa,int nowtop){
    top[u]=nowtop;
    dfn[u]=++cnt;
    rank[cnt]=u;
    if(hson[u]){
        dfs2(hson[u],u,nowtop);
        for(int i=0;i<G[u].size();i++){
            int v=G[u][i];
            if(v!=hson[u]&&v!=fa){
                dfs2(v,u,v);
            }
        }
    }
}
int ls(int u){
    return u<<1;
}
int rs(int u){
    return (u<<1)|1;
}
bool inrange(int L,int R,int l,int r){
    return (L<=l)&&(r<=R);
}
bool outofrange(int L,int R,int l,int r){
    return (R<l)||(r<L);
}
void build(int u,int L,int R){
    if(L!=R){
        int M=L+R>>1;
        build(ls(u),L,M);
        build(rs(u),M+1,R);
        a[u]=(node){L,R,(a[ls(u)].v%mod+a[rs(u)].v%mod)%mod,0};
    }
    else{
        a[u]=(node){L,R,v[u]%mod,0};
    }
}
void pushup(int u){
    a[u].v=(a[ls(u)].v%mod+a[rs(u)].v%mod)%mod;
}
void pushdown(int u){
    int L=a[u].l,R=a[u].r,M=L+R>>1,K=a[u].tag;
    if(L==R) return ;
    a[u].tag=0;
    a[ls(u)].tag+=K;
    a[ls(u)].tag%=mod;
    a[rs(u)].tag+=K;
    a[rs(u)].tag%=mod;
    a[ls(u)].v+=K*(M-L+1);
    a[ls(u)].v%=mod;
    a[rs(u)].v+=K*(R-M);
    a[rs(u)].v%=mod;
}
void update(int u,int L,int R,int k){
    if(a[u].tag) pushdown(u);
    if(inrange(L,R,a[u].l,a[u].r)){
        a[u].tag+=k;
        a[u].tag%=mod;
        a[u].v+=k*(a[u].r-a[u].l+1);
        a[u].v%=mod;
        pushdown(u);
    }
    else if(!outofrange(L,R,a[u].l,a[u].r)){
        update(ls(u),L,R,k);
        update(rs(u),L,R,k);
        pushup(u);
    }
}
int search(int u,int L,int R){
    if(a[u].tag) pushdown(u);
    if(inrange(L,R,a[u].l,a[u].r)){
        return a[u].v%mod;
    }
    else if(!outofrange(L,R,a[u].l,a[u].r)){
        return (search(ls(u),L,R)%mod+search(rs(u),L,R)%mod)%mod;
    }
    else return 0;
}
void updateintree(int x,int y,int k){
    while(top[x]!=top[y]){
        if(depth[top[x]]<depth[top[y]]){
            swap(x,y);
        }
        update(1,dfn[top[x]],dfn[x],k);
        x=f[top[x]];
    }
    if(depth[x]>depth[y]){
        swap(x,y);
    }
    update(1,dfn[x],dfn[y],k);
}
int queryintree(int x,int y){
    int ans=0;
    while(top[x]!=top[y]){
        if(depth[top[x]]<depth[top[y]]){
            swap(x,y);
        }
        ans+=search(1,dfn[top[x]],dfn[x]);
        ans%=mod;
        x=f[top[x]];
    }
    if(depth[x]>depth[y]){
        swap(x,y);
    }
    ans+=search(1,dfn[x],dfn[y]);
    return ans;
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    cin>>n>>m>>root>>mod;
    for(int i=1;i<=n;i++){
        cin>>val[i];
    }
    for(int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs1(root,-1,1);
    dfs2(root,-1,-1);
    for(int i=1;i<=n;i++){
        v[i]=val[dfn[i]];
    }
    build(1,1,n);
    for(int i=1;i<=m;i++){
        int opt,x,y,z;
        cin>>opt;
        if(opt==1){
            cin>>x>>y>>z;
            updateintree(x,y,z%mod);
        }
        if(opt==2){
            cin>>x>>y;
            cout<<queryintree(x,y)<<endl;
        }
        if(opt==3){
            cin>>x>>y;
            update(1,dfn[x],dfn[x]+size[x]-1,y%mod);
        }
        if(opt==4){
            cin>>x;
            cout<<search(1,dfn[x],dfn[x]+size[x]-1)<<endl;
        }
    }
    return 0;
} 

|