树链剖分板子求调 30pts 样例过了

P3384 【模板】重链剖分/树链剖分

itshawn @ 2023-04-20 16:49:02

调对的人奖励全机房人关注他

#include <bits/stdc++.h>
#define int long long
using namespace std;
// Problem inputs read in main(): node count n, operation count m, root r and
// the modulus mod. `a` is not referenced by the visible code; it is kept so
// any remaining user still links, but sized like the other arrays.
int n,m,r,a[1000001],mod;
// Adjacency lists of the undirected tree, indexed by node id (filled in main()).
std::vector<int> e[1000001];
// Per-node data produced by dfs(): depth, subtree size, parent and heavy child
// (0 means "no child"). Every node-indexed array uses the same bound so the
// static footprint stays small (each element is 8 bytes under the file's
// `#define int long long`).
int depth[1000001],size[1000001],fa[1000001],son[1000001];
// val[v]: initial value of node v (read in main()).
// rev[k]: initial value of the node placed at DFS position k (written by treecut()).
int rev[1000001],val[1000001];
void dfs(int x,int fath){
    fa[x]=fath;
    size[x]=1;
    depth[x]=depth[fath]+1;
    for(int i=0;i<e[x].size();i++){
        int v=e[x][i];
        if(v==fath) continue;
        dfs(v,x);
        size[x]+=size[v];
        if(size[son[x]]<size[v]||son[x]==0) son[x]=v;// 
    }
}
// id[v]: 1-based DFS position assigned to node v by treecut().
// top[v]: head node of the chain v was placed on by treecut().
// dfn: last DFS position handed out so far.
int id[1000001],top[1000001],dfn;
// Second decomposition pass: numbers the nodes so that every heavy chain
// occupies a contiguous range of positions, and copies each node's initial
// value to its position in rev[].
//   x - node being numbered
//   t - head of the chain x belongs to
void treecut(int x,int t){
    const int pos=++dfn;
    id[x]=pos;
    top[x]=t;
    rev[pos]=val[x];

    const int heavy=son[x];
    if(heavy!=0){
        // Continue the current chain first so its positions stay consecutive.
        treecut(heavy,t);
        // Every other child starts a new chain headed by itself.
        for(int nxt:e[x]){
            if(nxt!=heavy&&nxt!=fa[x]) treecut(nxt,nxt);
        }
    }
}
// One segment-tree node; tr[] is the heap-style node pool (children of p are
// 2p and 2p+1).
struct tree{
    int sum; // sum of the covered positions, kept reduced modulo mod
    int l;   // left end of the covered range (inclusive)
    int r;   // right end of the covered range (inclusive)
    int lz;  // pending addition not yet pushed to the children
}tr[1000001];
void built(int p,int l,int r){
    tr[p].l=l;
    tr[p].r=r;
    if(l==r){
        tr[p].sum=rev[l]%mod;
        return;
    }
    int mid=(l+r)/2;
    built(p*2,l,mid);
    built((p*2)+1,mid+1,r);
    tr[p].sum=(tr[p*2].sum%mod+tr[(p*2)+1].sum%mod)%mod;
}
void uplz(int p,int b){
    tr[p].lz+=b;tr[p].lz%=mod;
    tr[p].sum=(tr[p].sum%mod+(tr[p].r-tr[p].l+1)%mod*b%mod)%mod;
}
// Hands node p's pending addition down to both children and clears it.
// Does nothing when there is no pending addition.
void push(int p){
    const int pending=tr[p].lz;
    if(pending==0) return;
    uplz(p*2,pending);
    uplz(p*2+1,pending);
    tr[p].lz=0;
}
void modify(int p,int l,int r,int b){
    if(tr[p].l>r||tr[p].r<l) return;
    if(l<=tr[p].l&&tr[p].r<=r){
        tr[p].lz+=b;tr[p].lz%=mod;
        tr[p].sum=(tr[p].sum%mod+(tr[p].r-tr[p].l+1)%mod*b%mod)%mod;
        return;
    }
    push(p);
    modify(p*2,l,r,b);
    modify(p*2+1,l,r,b);
    tr[p].sum=(tr[p*2].sum%mod+tr[p*2+1].sum%mod)%mod;
}
// Adds w to every node on the tree path between u and v.
// While the endpoints sit on different chains, the endpoint whose chain head
// is deeper is lifted to the parent of that head, updating the chain segment
// it leaves behind. Once both are on one chain the remaining segment is a
// single contiguous range of DFS positions.
void tr_add(int u,int v,int w){
    while(top[u]!=top[v]){
        // Compare the depths of the chain HEADS, not of u and v themselves:
        // lifting the deeper endpoint can jump over the LCA when its chain
        // starts above the other endpoint's chain.
        if(depth[top[u]]<depth[top[v]]) std::swap(u,v);
        modify(1,id[top[u]],id[u],w);
        u=fa[top[u]];
    }
    // Same chain: the shallower node has the smaller DFS position.
    if(depth[u]>depth[v]) std::swap(u,v);
    modify(1,id[u],id[v],w);
}
int query(int p,int l,int r){
    if(tr[p].l>r||tr[p].r<l){
        return 0;
    }
    if(l<=tr[p].l&&tr[p].r<=r){
        return tr[p].sum%mod;
    }
    push(p);
    return (query(p*2,l,r)%mod+query(p*2+1,l,r)%mod)%mod;
}
// Returns the sum (modulo mod) of the values of all nodes on the tree path
// between u and v, walking chain by chain towards their common chain.
int get_ans(int u,int v){
    int ans=0;
    while(top[u]!=top[v]){
        // Lift the endpoint whose chain HEAD is deeper. Comparing depth[u]
        // with depth[v] instead can lift the wrong endpoint past the LCA and
        // count nodes that are not on the path.
        if(depth[top[u]]<depth[top[v]]) std::swap(u,v);
        ans=(ans+query(1,id[top[u]],id[u]))%mod;
        u=fa[top[u]];
    }
    // Same chain: the shallower node has the smaller DFS position.
    if(depth[u]>depth[v]) std::swap(u,v);
    ans=(ans+query(1,id[u],id[v]))%mod;
    return ans;
}
signed main(){
    int x,y,op,z;
    cin>>n>>m>>r>>mod;
    for(int i=1;i<=n;i++){
        cin>>val[i];
    }
    for(int i=1;i<n;i++){
        scanf("%d%d",&x,&y);
        e[x].push_back(y);
        e[y].push_back(x);
    }
    dfs(r,r);
    treecut(r,r);
    built(1,1,n);
    for(int i=1;i<=m;i++){
        cin>>op;
        if(op==1){
            cin>>x>>y>>z;
            tr_add(x,y,z%mod);
        }
        if(op==2){
            cin>>x>>y;
            cout<<get_ans(x,y)<<endl;
        }
        if(op==3){
            cin>>x>>z;
            modify(1,id[x],id[x]+size[x]-1,z%mod);
        }
        if(op==4){
            cin>>x;
            cout<<query(1,id[x],id[x]+size[x]-1)%mod<<endl;
        }
    }
    return 0;
}

by LuckiestShawn @ 2024-02-21 11:27:50

呜呜呜,现在重写一遍就37pts了


|