10分求调

P3384 【模板】重链剖分/树链剖分

lostinue @ 2024-03-13 15:57:04

#include<bits/stdc++.h>
#define ll long long
#define debug printf("ciallo\n")
#define rep(i,aaa,bbb) for(int i=(aaa);i<=(bbb);i++)
#define per(i,aaa,bbb) for(int i=(aaa);i>=(bbb);i--)
#define pb push_back
using namespace std;

// Reads one integer from stdin.
// Skips every non-digit character before the number; if any of those skipped
// characters is '-', the result is negated. Digits are then accumulated until
// the first non-digit, which is consumed and discarded.
int read(){
    int sign=1;
    int value=0;
    char ch=getchar();
    for(;ch<'0'||ch>'9';ch=getchar()){
        if(ch=='-')sign=-1;
    }
    for(;ch>='0'&&ch<='9';ch=getchar()){
        value=value*10+(ch-'0');
    }
    return value*sign;
}

const int inf=0x7fffffff;  // not referenced by the code visible here
const int maxn=1e5+10;     // capacity for vertex-indexed arrays (vertices are numbered 1..n)

vector<int>t[maxn];        // adjacency lists of the tree (main pushes each edge both ways)
int n,m,root,mod;          // vertex count, number of operations, root vertex, modulus for all sums
int w[maxn],cnt;           // w[v]: initial weight of vertex v (already reduced mod `mod`); cnt: DFS-order counter advanced by dfs2
// Indexed by vertex number: depth, parent, subtree size, heavy child (0 if none), head of the heavy chain.
// NOTE(review): with `using namespace std;` the global name `size` is ambiguous with std::size under C++17; compile as C++14 or rename the array.
int dep[maxn],fa[maxn],size[maxn],son[maxn],top[maxn];
// id[v]: position of vertex v in DFS order (the segment tree's index space); wt[p]: weight stored at position p.
// Segment-tree functions take positions, so callers pass id[v], not v.
int id[maxn],wt[maxn];
//Segment tree ("xds" is the pinyin abbreviation for "segment tree") ------------------------------------
struct node{
    ll x,add;int l,r;      // x: range sum mod `mod`; add: pending lazy addition for the children; [l, r]: covered positions
}T[maxn<<2];

void build(int x,int l,int r){
    T[x].l=l,T[x].r=r;
    if(l==r){
        T[x].x=wt[l];
        return;
    }
    int mid=(T[x].l+T[x].r)>>1;
    build(x<<1,l,mid);
    build(x<<1|1,mid+1,r);
    T[x].x=(T[x<<1].x+T[x<<1|1].x)%mod;
}

inline void down(int x){
    if(T[x].add==0)return;
    T[x<<1].add+=T[x].add;
    T[x<<1|1].add+=T[x].add;
    T[x<<1].x+=T[x].add*(T[x<<1].r-T[x<<1].l+1);
    T[x<<1|1].x+=T[x].add*(T[x<<1|1].r-T[x<<1|1].l+1);
    T[x<<1].add%=mod;
    T[x<<1|1].add%=mod;
    T[x<<1].x%=mod;
    T[x<<1|1].x%=mod;
    T[x].add=0;
}

void rewrite(int x,int lq,int rq,int vl){
    if(T[x].l>=lq&&T[x].r<=rq){
        T[x].add=(T[x].add+vl)%mod;
        T[x].x=(T[x].x+(T[x].r-T[x].l+1)*vl)%mod;
        return;
    }
    down(x);
    int mid=(T[x].l+T[x].r)>>1;
    if(lq<=mid)rewrite(x<<1,lq,rq,vl);
    if(rq>mid)rewrite(x<<1|1,lq,rq,vl);
}

int count(int x,int lq,int rq){
    if(T[x].l>=lq&&T[x].r<=rq)return T[x].x;
    ll ccc=0;
    int mid=(T[x].l+T[x].r)>>1;
    down(x);
    if(lq<=mid)ccc+=count(x<<1,lq,rq);
    ccc%=mod;
    if(rq>mid)ccc+=count(x<<1|1,lq,rq);
    return ccc%mod;
}
//---------------------------------------
// First DFS of the decomposition: for every vertex in x's subtree, records
// parent (fa), depth (dep) and subtree size (size). It also selects son[x],
// the child with the largest subtree; on ties the earliest child in t[x] is kept.
// Leaves keep son[] unchanged (zero from global initialization).
void dfs1(int x,int f,int deep){
    fa[x]=f;
    dep[x]=deep;
    size[x]=1;
    int best=-1;
    for(const int y:t[x]){
        if(y==f)continue;
        dfs1(y,x,deep+1);
        size[x]+=size[y];
        if(best<size[y]){
            best=size[y];
            son[x]=y;
        }
    }
}

void dfs2(int x,int f){
    id[x]=++cnt;
    wt[cnt]=wt[x];
    top[x]=f;
    if(!son[x])return;
    dfs2(son[x],f);
    for(int y:t[x]){
        if(y!=fa[x]&&y!=son[x]){
            dfs2(y,y);
        }
    }
}

// Operation 1: adds z (reduced modulo mod) to every vertex on the path x -> y.
inline void q1(int x,int y,int z){
    const int delta=z%mod;
    // Repeatedly lift the endpoint whose chain head is deeper until both share a chain.
    for(;top[x]!=top[y];x=fa[top[x]]){
        if(dep[top[y]]>dep[top[x]])swap(x,y);
        rewrite(1,id[top[x]],id[x],delta);
    }
    // Both endpoints are on one chain now: update the segment between them.
    const int upper=dep[x]<dep[y]?x:y;
    const int lower=dep[x]<dep[y]?y:x;
    rewrite(1,id[upper],id[lower],delta);
}

inline int q2(int x,int y){
    ll ans=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])swap(x,y);
        ans=(ans+count(1,id[top[x]],id[x]))%mod;
        x=fa[top[x]];
    }
    if(dep[x]>dep[y])swap(x,y);
    return (ans+count(1,id[x],id[y]))%mod;
}

inline void q3(int x,int z){
    rewrite(1,id[x],id[x]+size[x]-1,z);
}

// Operation 4: returns the sum modulo mod of vertex values in the subtree of x,
// which occupies DFS positions [id[x], id[x]+size[x]-1].
inline int q4(int x){
    const int first=id[x];
    const int last=first+size[x]-1;
    return count(1,first,last);
}

// Reads n, m, root, mod, the n vertex weights and n-1 undirected edges.
// Builds the heavy-light decomposition and segment tree, then processes m
// operations (1: path add, 2: path sum, 3: subtree add, 4: subtree sum),
// printing the answers to types 2 and 4.
int main(){
    n=read(),m=read(),root=read(),mod=read();
    rep(i,1,n)w[i]=read()%mod;
    rep(i,2,n){
        int u=read(),v=read();
        t[u].pb(v),t[v].pb(u);
    }
    dfs1(root,0,0);
    // The root heads its own heavy chain. Passing 0 here would set top[] to 0
    // along the root's chain, and the path queries would then use
    // dep[0]/fa[0]/id[0], which are never assigned.
    dfs2(root,root);
    build(1,1,n);
    while(m--){
        int k=read();
        if(k==1){
            int x=read(),y=read(),z=read();
            q1(x,y,z);
        }
        if(k==2){
            int x=read(),y=read();
            printf("%d\n",q2(x,y));
        }
        if(k==3){
            int x=read(),z=read();
            q3(x,z);
        }
        if(k==4){
            int x=read();
            printf("%d\n",q4(x));
        }
    }
    return 0;
}

|