求调(最后三个点RE,73分!)

P3384 【模板】重链剖分/树链剖分

Jiangyx1028 @ 2024-08-07 19:41:06


#include<bits/stdc++.h>
using namespace std;
// Every `int` below is really a 64-bit `long long`; this is why main is
// declared `signed` and all scanf/printf calls use %lld.
#define int long long
// N: capacity of the per-vertex arrays; M: capacity of the edge arrays
// (main stores each input edge in both directions).
const int N=3E5+10,M=N*2;

// n, m, mod: filled by the first scanf in main (mod is the modulus applied to
//            every stored sum).
// root:      node index the path/subtree helpers hand to update()/query() as
//            the segment-tree node to start from.
// cnt:       running counter used by get_dfs to hand out positions 1, 2, ...
int n,m,root,mod,cnt=0;

// Adjacency lists stored as linked lists of edge slots:
// h[v]  = first edge slot of vertex v (-1 when empty, set by memset in main),
// e[i]  = target vertex of edge slot i,
// ne[i] = next edge slot in the same list (-1 terminates),
// idx   = next unused edge slot.
int h[N],e[M],ne[M],idx=0;
// Filled by get_son: son[v] = child with the largest subtree (0 if v has no
// child), fa[v] = parent, sz[v] = subtree size. w[v] = input value of vertex v.
int son[N],fa[N],sz[N],w[N];
// Filled by get_dfs: id[v] = position assigned to v, nw[pos] = w of the vertex
// at that position, top[v] = first vertex of the chain v was placed on.
// dep[v] = depth of v as set by get_son (the start vertex gets depth 1).
int nw[N],id[N],dep[N],top[N];

// One segment-tree node. Field order matters: build() fills a node with the
// aggregate initializer {l, r, sum, flag}.
struct seg{
    int l;     // left end of the covered position range
    int r;     // right end of the covered position range
    int sum;   // sum over [l, r], kept modulo `mod` by pushup/update
    int flag;  // pending per-element increment not yet pushed to the children
}tr[N*4];

// Inserts the directed edge a -> b at the front of a's adjacency list,
// consuming the next free edge slot.
void add(int a,int b){
    e[idx]=b;
    ne[idx]=h[a];
    h[a]=idx;
    ++idx;
}

// First DFS pass: records depth, parent and subtree size of every vertex
// reachable from u, and picks for each vertex the child with the largest
// subtree (son[]).
//   u      - vertex being visited
//   father - vertex we arrived from (skipped when scanning neighbours)
//   depth  - depth to assign to u
void get_son(int u,int father,int depth){
    dep[u]=depth;
    fa[u]=father;
    sz[u]=1;

    for(int edge=h[u];edge!=-1;edge=ne[edge]){
        int child=e[edge];
        if(child!=father){
            get_son(child,u,depth+1);
            sz[u]+=sz[child];
            // Strict comparison: on ties the earlier-visited child stays.
            if(sz[child]>sz[son[u]])son[u]=child;
        }
    }
}
// Second DFS pass: assigns the next position to u, copies its value into nw[],
// and records t as the top of u's chain. The son[] child is visited first and
// inherits t; every other child starts a chain of its own.
void get_dfs(int u,int t){
    ++cnt;
    id[u]=cnt;
    nw[cnt]=w[u];
    top[u]=t;

    if(son[u]){
        get_dfs(son[u],t);

        for(int edge=h[u];edge!=-1;edge=ne[edge]){
            int child=e[edge];
            if(child!=son[u]&&child!=fa[u]){
                get_dfs(child,child);
            }
        }
    }
}

void pushup(int p){
    tr[p].sum=(tr[p<<1].sum+tr[p<<1|1].sum)%mod;
}

void pushdown(int p){
    auto &root=tr[p],&left=tr[p<<1],&right=tr[p<<1|1];
    if(root.flag){
        left.sum=(left.sum+root.flag*(left.r-left.l+1))%mod;
        left.flag=(left.flag+root.flag)%mod;
        right.sum=(right.sum+root.flag*(right.r-right.l+1))%mod;
        right.flag=(right.flag+root.flag)%mod;
        root.flag=0;
    }
}

void build(int p,int l,int r){
    tr[p]={l,r,nw[r],0};
    if(l==r)return ;
    int mid=(l+r)>>1;
    build(p<<1,l,mid);
    build(p<<1|1,mid+1,r);
    pushup(p);
}

void update(int u,int l,int r,int k){
    if(l<=tr[u].l&&tr[u].r<=r){
        tr[u].flag=(tr[u].flag+k)%mod;
        tr[u].sum=(tr[u].sum+k*(tr[u].r-tr[u].l+1))%mod;
        return ;
    }
    pushdown(u);
    int mid=(tr[u].l+tr[u].r)>>1;
    if(l<=mid) update(u<<1,l,r,k);
    if(r>mid)update(u<<1|1,l,r,k);
    pushup(u);
}

int query(int u,int l,int r){
    if(l<=tr[u].l&&tr[u].r<=r)return tr[u].sum;
    pushdown(u);
    int mid=(tr[u].l+tr[u].r)>>1;
    int res=0;
    if(l<=mid) res=(res+query(u<<1,l,r))%mod;
    if(r>mid)res=(res+query(u<<1|1,l,r))%mod;
    return res%mod;
}

// Adds k to every vertex on the path between u and v.
// While the two vertices sit on different chains, the one whose chain top is
// deeper has its chain segment [id[top], id[u]] updated and then jumps to the
// parent of that top. Once on the same chain, the remaining segment between
// the two positions is updated. All updates start at segment-tree node `root`.
void update_path(int u,int v,int k){
    for(;top[u]!=top[v];u=fa[top[u]]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);
        update(root,id[top[u]],id[u],k);
    }
    // Make v the shallower endpoint so id[v] <= id[u] on the shared chain.
    if(dep[u]<dep[v])swap(u,v);
    update(root,id[v],id[u],k);
}

// Returns the sum (modulo `mod`) of the vertices on the path between u and v.
// Climbs chain by chain exactly like update_path, accumulating the sum of
// each chain segment; all queries start at segment-tree node `root`.
int query_path(int u,int v){
    int total=0;
    for(;top[u]!=top[v];u=fa[top[u]]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);
        total=(total+query(root,id[top[u]],id[u]))%mod;
    }
    // Same chain now: v becomes the shallower endpoint.
    if(dep[u]<dep[v])swap(u,v);
    total=(total+query(root,id[v],id[u]))%mod;
    return total%mod;
}

// Adds k to the sz[u] consecutive positions starting at id[u], i.e. the
// block get_dfs assigned to u and its descendants.
void update_tree(int u,int k){
    int first=id[u];
    int last=first+sz[u]-1;
    update(root,first,last,k);
}

// Returns the sum over the sz[u] consecutive positions starting at id[u],
// i.e. the block get_dfs assigned to u and its descendants.
int query_tree(int u){
    int first=id[u];
    int last=first+sz[u]-1;
    return query(root,first,last);
}

// Entry point. Reads n, m, the tree's root vertex and the modulus, then the
// vertex values and the n-1 edges; runs both DFS passes, builds the segment
// tree, and processes m operations:
//   1 u v k : add k to every vertex on the path u..v
//   2 u v   : print the path sum modulo mod
//   3 u k   : add k to every vertex in u's subtree
//   4 u     : print the subtree sum modulo mod
// All I/O goes through scanf/printf.
signed main(){
    memset(h,-1,sizeof h);  // -1 marks the end of every adjacency list

    // The global `root` is what update_path/query_path/update_tree/query_tree
    // pass to update()/query() as the segment tree's top node, so it has to be
    // 1 for the p<<1 / p<<1|1 child indices to stay inside tr[N*4]. Reading
    // the tree's root vertex into it scaled every node index by that vertex
    // label and ran off the end of tr[] (the runtime error). The vertex is
    // therefore kept in a local and `root` is pinned to 1.
    int tree_root=0;
    scanf("%lld%lld%lld%lld",&n,&m,&tree_root,&mod);
    root=1;

    for(int i=1;i<=n;i++)scanf("%lld",&w[i]);
    for(int i=1;i<n;i++){
        int a,b;scanf("%lld%lld",&a,&b);
        add(a,b);add(b,a);  // undirected edge stored in both directions
    }

    get_son(tree_root,-1,1);
    get_dfs(tree_root,tree_root);
    build(root,1,n);

    while(m--){
        int op,u,v,k;
        scanf("%lld",&op);
        if(op==1){
            scanf("%lld%lld%lld",&u,&v,&k);
            update_path(u,v,k);
        }else if(op==2){
            scanf("%lld%lld",&u,&v);
            printf("%lld\n",query_path(u,v)%mod);
        }else if(op==3){
            scanf("%lld%lld",&u,&k);
            update_tree(u,k);
        }else if(op==4){
            scanf("%lld",&u);
            // query() may return an unreduced leaf value, hence the final %mod.
            printf("%lld\n",query_tree(u)%mod);
        }
    }
    return 0;
}

|