37 pts 树剖求助

P3384 【模板】重链剖分/树链剖分

SilverLi @ 2023-05-13 18:11:16

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=1e6+5;
int n,m,r,mod,ax[N];
int dfn[N],Index,a[N];
int fa[N],d[N],top[N];
int son[N],si[N];
vector<int> g[N];
int t[N],ad1[N],ad2[N];
void add1(int p,int l,int r) {
    if(ad1[p]) {
        t[p<<1]+=ad1[p]*((r-l+1)>>1);
        t[p<<1|1]+=ad1[p]*((r-l+1)>>1);
        ad1[p<<1]+=ad1[p];
        ad1[p<<1|1]+=ad1[p];
        ad1[p]=0;
    }
}
void add2(int p,int l,int r) {
    if(ad2[p]) {
        t[p<<1]=ad2[p]*((r-l+1)>>1);
        t[p<<1|1]=ad2[p]*((r-l+1)>>1);
        ad2[p<<1]=ad2[p];
        ad2[p<<1|1]=ad2[p];
        ad2[p]=0;
        ad1[p<<1]=ad1[p<<1|1]=0;
    }
}
void down(int p,int l,int r) {add2(p,l,r),add1(p,l,r);}
void build(int l,int r,int p) {
    if(l==r) {  t[p]=a[l];  return; }
    int m=l+r>>1;
    build(l,m,p<<1);build(m+1,r,p<<1|1);
    t[p]=(t[p<<1]+t[p<<1|1]+mod)%mod;
}
void ADD(int l,int r,int S,int T,int p,int ch) {
    if(l>=S&&r<=T) {
        t[p]+=ch*(r-l+1),ad1[p]+=ch;
        return;
    }
    int m=l+r>>1;
    down(p,l,r);
    if(S<=m)    ADD(l,m,S,T,p<<1,ch);
    if(T>m) ADD(m+1,r,S,T,p<<1|1,ch);
    t[p]=(t[p<<1]+t[p<<1|1]+mod)%mod;
    return;
}
void COVER(int l,int r,int S,int T,int p,int ch) {
    if(l>=S&&r<=T) {
        t[p]=ch*(r-l+1),
        ad2[p]=ch,
        ad1[p]=0;
        return;
    }
    int m=l+r>>1;
    down(p,l,r);
    if(S<=m)    COVER(l,m,S,T,p<<1,ch);
    if(T>m) COVER(m+1,r,S,T,p<<1|1,ch);
    t[p]=(t[p<<1]+t[p<<1|1]+mod)%mod;
    return;
}
int SUM(int l,int r,int S,int T,int p) {
    if(l>=S&&r<=T)  return t[p];
    int m=l+r>>1,sum=0;
    down(p,l,r);
    if(S<=m)    sum=SUM(l,m,S,T,p<<1)%mod;
    if(T>m) sum+=SUM(m+1,r,S,T,p<<1|1);
    return (sum+mod)%mod;
}
void dfs(int v,int f) {
    d[v]=d[f]+1,si[v]=1,fa[v]=f;
    int mx=0;
    for(int i:g[v])
        if(i!=f) {
            dfs(i,v),si[v]+=si[i];
            if(mx<si[i])    son[v]=i,mx=si[i];
        }
}
void dfs2(int v,int deep) {
    dfn[v]=++Index,a[Index]=ax[v]%mod;
    top[v]=deep;
    if(!son[v]) return;
    dfs2(son[v],deep);
    for(int i:g[v])
        if(i!=fa[v]&&i!=son[v]) dfs2(i,i);
}
inline void add(int u,int v,int ch) {
    while(top[u]!=top[v]) {
        if(d[top[u]]<d[top[v]]) swap(u,v);
        ADD(1,n,dfn[top[u]],dfn[u],1,ch);
        u=fa[top[u]];
    }
    if(d[u]>d[v])   swap(u,v);
    ADD(1,n,dfn[u],dfn[v],1,ch);
}
inline void addsub(int u,int ch) {ADD(1,n,dfn[u],dfn[u]+si[u]-1,1,ch);}
inline int sum(int u,int v) {
    int res=0;
    while(top[u]!=top[v]) {
        if(d[top[u]]<d[top[v]]) swap(u,v);
        res+=SUM(1,n,dfn[top[u]],dfn[u],1);res%=mod;
        u=fa[top[u]];
    }
    if(d[u]>d[v])   swap(u,v);
    res+=SUM(1,n,dfn[u],dfn[v],1);
    return res%mod;
}
inline int sumsub(int u) {return SUM(1,n,dfn[u],dfn[u]+si[u]-1,1)%mod;}
signed main() {
    cin>>n>>m>>r>>mod;
    for(int i=1;i<=n;++i)   cin>>ax[i];
    for(int i=1;i<n;++i) {int x,y;cin>>x>>y;g[x].push_back(y),g[y].push_back(x);}
    dfs(r,0);dfs2(r,r);
    build(1,n,1);
    while(m--) {
        int opt,x,y,z;
        cin>>opt>>x;
        if(opt==1) {
            cin>>y>>z;
            add(x,y,z);
        } else if(opt==2) {
            cin>>y;
            cout<<sum(x,y)<<endl;
        } else if(opt==3) {
            cin>>z;
            addsub(x,z);
        } else  cout<<sumsub(x)<<endl;
    }
    return 0;
}

by Killer_joke @ 2023-05-14 22:00:26

@NM_ljy add1写的有问题,左右儿子的区间长度不一定是直接除二。


by SilverLi @ 2023-05-14 22:03:38

@Killer_joke thx,+已A

我在改以前用树状数组写的**代码


|