WA13pts求调

P3384 【模板】重链剖分/树链剖分

florrer_cy @ 2024-12-21 08:51:41

#include<bits/stdc++.h>
using namespace std;
#define N 100005
int add[N<<2],sum[N<<2],s[N<<2],t[N<<2];
int dep[N],fa[N],size[N],son[N],dfn[N],a[N],top[N],id[N];
int tot,n,m,r,p;
vector<int>e[N];
void dfs1(int,int);
void dfs2(int,int);
void build(int,int,int);
void pushdown(int);
int getsum1(int,int,int);
int getsum2(int,int);
void update_add1(int,int,int,int);
void update_add2(int,int,int);
int main(){
    ios::sync_with_stdio(0);
    cin>>n>>m>>r>>p;
    for(int i=1;i<=n;i++){
        cin>>a[i];
        a[i]%=p;
    }
    for(int i=1,u,v;i<n;i++){
        cin>>u>>v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    dfs1(r,0);
    dfs2(r,r);
    build(1,1,n);
    //cout<<size[4]<<endl;
    for(int type,x,y,z;m--;){
        cin>>type>>x;
        switch(type){
            case 1:
                cin>>y>>z;
                update_add2(x,y,z%p);
                break;
            case 2:
                cin>>y;
                cout<<getsum2(x,y)<<endl;
                break;
            case 3:
                cin>>z;
                update_add1(1,dfn[x],dfn[x]+size[x]-1,z%p);
                break;
            case 4:
                cout<<getsum1(1,dfn[x],dfn[x]+size[x]-1)<<endl;
        }
    }
    return 0;
}
void dfs1(int u,int f){
    fa[u]=f;
    dep[u]=dep[f]+1;
    size[u]=1;
    for(auto v:e[u]){
        if(v==f) continue;
        dfs1(v,u);
        size[u]+=size[v];
        if(size[v]>size[son[u]]) son[u]=v;
    }
}
void dfs2(int u,int t){
    dfn[u]=++tot;
    id[dfn[u]]=u;
    top[u]=t;
    if(son[u]==0) return;
    dfs2(son[u],t);
    for(auto v:e[u]){
        if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
    }
}
void build(int tid,int l,int r){
    s[tid]=l;t[tid]=r;
    if(l==r){s[tid]=t[tid]=l;sum[tid]=a[id[l]]%p;return;}
    int mid=(l+r)>>1;
    build(tid<<1,l,mid);
    build((tid<<1)|1,mid+1,r);
    sum[tid]=(1ll*sum[tid<<1]+sum[(tid<<1)|1])%p;
}
void pushdown(int tid){
    if(add[tid]==0) return;
    sum[tid<<1]=(sum[tid<<1]+1ll*(t[tid<<1]-s[tid<<1]+1)*add[tid]%p)%p;
    sum[(tid<<1)|1]=(sum[(tid<<1)|1]+1ll*(t[(tid<<1)|1]-s[(tid<<1)|1]+1)*add[tid]%p)%p;
    add[tid<<1]=(1ll*add[tid<<1]+add[tid])%p;
    add[(tid<<1)|1]=(1ll*add[(tid<<1)|1]+add[tid])%p;
    add[tid]=0;
}
int getsum1(int tid,int l,int r){
    if(s[tid]>=l&&t[tid]<=r) return sum[tid];
    int mid=(s[tid]+t[tid])>>1,ret=0;
    pushdown(tid);
    if(l<=mid) ret=getsum1(tid<<1,l,mid);
    if(r>mid) ret=(ret+1ll*getsum1((tid<<1)|1,mid+1,r))%p;
    return ret;
}
int getsum2(int u,int v){
    int ret=0;
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        ret=(ret+1ll*getsum1(1,dfn[top[u]],dfn[u]))%p;
        //cerr<<top[u]<<" "<<u<<" ";
        //cerr<<getsum1(1,dfn[top[u]],dfn[u])<<endl;
        u=fa[top[u]];
    }
    if(dep[u]>dep[v]) swap(u,v);
    ret=(ret+1ll*getsum1(1,dfn[u],dfn[v]))%p;
   // cerr<<u<<" "<<v<<" ";
   // cerr<<getsum1(1,dfn[v],dfn[u])<<endl;
    return ret;
}
void update_add1(int tid,int l,int r,int k){
    if(s[tid]>=l&&t[tid]<=r){
        sum[tid]=(sum[tid]+1ll*(t[tid]-s[tid]+1)*k%p)%p;
        add[tid]=(add[tid]+1ll*k)%p;
        return;
    }
//  //cout<<s[tid]<<" "<<t[tid]<<endl;
    int mid=(s[tid]+t[tid])>>1;
    pushdown(tid);
    if(l<=mid) update_add1(tid<<1,l,r,k);
    if(r>mid) update_add1((tid<<1)|1,l,r,k);
}
void update_add2(int u,int v,int k){
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]]) swap(u,v);
        update_add1(1,dfn[top[u]],dfn[u],k);
        u=fa[top[u]];
    }
    if(dep[u]>dep[v])    swap(u,v);
    update_add1(1,dfn[u],dfn[v],k);
}

|