树剖写死19pts求助

P3384 【模板】重链剖分/树链剖分

Libingyue2011 @ 2024-02-02 12:31:33

#include<bits/stdc++.h>
using namespace std;
int n,m,rt,p;
int tda[100010],a[100010];
int head[100010],to[200010],nxt[200010],tot;
void add(int u,int v){
    to[++tot]=v;
    nxt[tot]=head[u];
    head[u]=tot;
}
int fa[100010],son[100010],siz[100010],dep[100010];
void dfs(int x,int f){
    fa[x]=f,siz[x]=1;
    for(int i=head[x];i;i=nxt[i]){
        if(to[i]==f) continue;
        dep[to[i]]=dep[x]+1;
        dfs(to[i],x);
        siz[x]+=siz[to[i]];
        if(siz[son[x]]<siz[to[i]]) son[x]=to[i];
    }
}
int id[100010],top[100010],cnt;
void dfs2(int x,int t){
    id[x]=++cnt,a[id[x]]=tda[x],top[x]=t;
    if(!son[x]) return;
    dfs2(son[x],t);
    for(int i=head[x];i;i=nxt[i]){
        if(to[i]==son[x] || to[i]==fa[x]) continue;
        dfs2(to[i],to[i]);
    }
}
int t[400010],slow[400010];
void creat(int l,int r,int root){
    if(l==r){
        t[root]=a[l]%p;
        return; 
    }
    int mid=l+r>>1;
    creat(l,mid,root*2);
    creat(mid+1,r,root*2+1);
    t[root]=t[root*2]+t[root*2+1];
}
void pushdown(int root,int l,int r){
    if(!slow[root]) return;
    int mid=l+r>>1;
    slow[root*2]+=slow[root];
    slow[root*2+1]+=slow[root];
    t[root*2]+=slow[root]*(mid-l+1);
    t[root*2+1]+=slow[root]*(r-mid);
    t[root*2]%=p,slow[root*2]%=p;
    t[root*2+1]%=p,slow[root*2+1]%=p;
    slow[root]=0;
}
void update(int l,int r,int k,int x=1,int y=n,int root=1){
    if(l<=x && y<=r){
        t[root]+=(y-x+1)*k,slow[root]+=k;
        t[root]%=p,slow[root]%=p;
        return;
    }
    pushdown(root,x,y);
    int mid=x+y>>1;
    if(l<=mid) update(l,r,k,x,mid,root*2);
    if(mid<r) update(l,r,k,mid+1,y,root*2+1);
}
int sum(int l,int r,int x=1,int y=n,int root=1){
    if(l<=x && y<=r) return t[root];
    pushdown(root,x,y);
    int mid=x+y>>1,tot=0;
    if(l<=mid) tot=sum(l,r,x,mid,root*2);
    if(mid<r) tot=(tot+sum(l,r,mid+1,y,root*2+1))%p;
    return tot; 
}
void update_path(int u,int v,int k){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        update(id[top[u]],id[u],k);
        u=fa[top[u]];
    }
    if(dep[u]>dep[v]) swap(u,v);
    update(id[u],id[v],k);
}
int sum_path(int u,int v){
    int ans=0;
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        ans=(ans+sum(id[top[u]],id[u]))%p;
        u=fa[top[u]];
    }
    if(dep[u]>dep[v]) swap(u,v);
    ans=(ans+sum(id[u],id[v]))%p;
    return ans;
}
void update_tree(int u,int k){
    update(id[u],id[u]+siz[u]-1,k);
}
int sum_tree(int u){
    return sum(id[u],id[u]+siz[u]-1)%p;
}
signed main() {
    cin>>n>>m>>rt>>p;
    for(int i=1;i<=n;i++) cin>>tda[i];
    for(int i=1;i<n;i++){
        int u,v;
        cin>>u>>v;
        add(u,v),add(v,u);
    }
    dfs(rt,0);
    dfs2(rt,0);
    creat(1,n,1);
    while(m--){
        int opt;
        cin>>opt;
        if(opt==1){
            int x,y,k;
            cin>>x>>y>>k;
            update_path(x,y,k);
        }
        if(opt==2){
            int x,y;
            cin>>x>>y;
            cout<<sum_path(x,y)<<"\n";
        }
        if(opt==3){
            int x,k;
            cin>>x>>k;
            update_tree(x,k);
        }
        if(opt==4){
            int x;
            cin>>x;
            cout<<sum_tree(x)<<"\n";
        }
    }
    return 0;
}

|