救救孩子37pts,wa on#4~#10,玄关

P3384 【模板】重链剖分/树链剖分

110821zj_hhx @ 2024-07-27 21:13:04

#include<iostream>
#include<cstdio>
#include<vector>
#define int long long
using namespace std;
// ---- Problem input ----
int n;                       // number of vertices
int m;                       // number of operations
int root;                    // root vertex of the tree
int p;                       // modulus applied to reported sums
int a[100005];               // a[v]: initial value of vertex v
std::vector<int> s[100005];  // s[v]: neighbours of v (undirected adjacency list)

// ---- Segment tree storage (about 4x the vertex bound) ----
int tr[400005];              // tr[k]: sum stored at segment-tree node k
int lazy[400005];            // lazy[k]: addition not yet pushed to k's children

// ---- Tree decomposition data, filled by the two DFS passes ----
int pa[100005];              // pa[v]: parent of v (0 for the root)
int dep[100005];             // dep[v]: depth of v; the root has depth 1
int siz[100005];             // siz[v]: number of vertices in v's subtree
int son[100005];             // son[v]: child of v with the largest subtree (0 if none)
int top[100005];             // top[v]: first vertex of the chain containing v
int tot;                     // number of DFS positions handed out so far
int dfn[100005];             // dfn[i]: vertex placed at DFS position i
int mp[100005];              // mp[v]: DFS position of vertex v (inverse of dfn)
void build(int k,int l,int r){
    if(l==r){
        tr[k]=a[dfn[l]];
        return;
    }
    int mid=(l+r)/2;
    build(k*2,l,mid);
    build(k*2+1,mid+1,r);
    tr[k]=tr[k*2]+tr[k*2+1];
}
void change(int k,int l,int r,int x,int y,int v){
    if(r<x||l>y) return;
    if(l>=x&&r<=y){
        lazy[k]+=v;
        tr[k]+=v*(r-l+1);
        return;
    }
    if(lazy[k]){
        int mid=(l+r)/2;
        lazy[2*k]+=lazy[k],tr[2*k]+=lazy[k]*(mid-l+1);
        lazy[2*k+1]+=lazy[k],tr[2*k+1]+=lazy[k]*(r-mid);
        lazy[k]=0;
    }
    int mid=(l+r)/2;
    change(k*2,l,mid,x,y,v);
    change(k*2+1,mid+1,r,x,y,v);
    tr[k]=tr[2*k]+tr[2*k+1];
}
int find(int k,int l,int r,int x,int y){
    if(r<x||l>y) return 0;
    if(l>=x&&r<=y) return tr[k];
    int mid=(l+r)/2;
    lazy[2*k]+=lazy[k],tr[2*k]+=lazy[k]*(mid-l+1);
    lazy[2*k+1]+=lazy[k],tr[2*k+1]+=lazy[k]*(r-mid);
    lazy[k]=0;
    return find(k*2,l,mid,x,y)+find(k*2+1,mid+1,r,x,y);
}
void dfs1(int x,int fa){
    pa[x]=fa;
    dfn[++tot]=x;
    mp[x]=tot;
    dep[x]=dep[fa]+1;
    for(int i=0;i<s[x].size();i++){
        if(s[x][i]==fa) continue;
        dfs1(s[x][i],x);
        siz[x]+=siz[s[x][i]]; 
        if(siz[s[x][i]]>siz[son[x]]) son[x]=s[x][i]; 
    }
    siz[x]++;
}
void dfs2(int x,int tp){
    top[x]=tp;
    if(son[x]) dfs2(son[x],tp);
    for(int i=0;i<s[x].size();i++){
        if(s[x][i]==pa[x]||s[x][i]==son[x]) continue;
        dfs2(s[x][i],s[x][i]);
    }
}
signed main(){
    cin>>n>>m>>root>>p;
    for(int i=1;i<=n;i++) cin>>a[i];
    for(int i=1;i<n;i++){
        int x,y;
        cin>>x>>y;
        s[x].push_back(y);
        s[y].push_back(x);
    }
    dfs1(root,0);
    dfs2(root,root);
    build(1,1,n);
    while(m--){
        int op,x,y,z;
        cin>>op>>x;
        if(op==1){
            cin>>y>>z;
            while(top[x]!=top[y]){
                if(dep[x]<dep[y]) swap(x,y);
                change(1,1,n,mp[top[x]],mp[x],z);
                x=pa[top[x]];
            }
            if(dep[x]<dep[y]) swap(x,y);
            change(1,1,n,mp[y],mp[x],z);
        }else if(op==2){
            int ans=0;
            cin>>y;
            while(top[x]!=top[y]){
                if(dep[x]<dep[y]) swap(x,y);
                ans+=find(1,1,n,mp[top[x]],mp[x]);
                x=pa[top[x]];
            }
            if(dep[x]<dep[y]) swap(x,y);
            ans+=find(1,1,n,mp[y],mp[x]);
            cout<<ans%p<<endl;
        }else if(op==3){
            cin>>z;
            change(1,1,n,mp[x],mp[x]+siz[x]-1,z);
        }else{
            cout<<find(1,1,n,mp[x],mp[x]+siz[x]-1)%p<<endl;
        }
    }
    return 0;
}

|