求助,案例过不了

P3384 【模板】重链剖分/树链剖分

Willan_Lian @ 2023-12-03 13:20:41


#include<bits/stdc++.h>
using namespace std;
int n,m,rt;
int w[100010],cnt=0;
long long res=0;
int MOD;
struct E{
    int to,nxt;
} E[200020];

int head[100010];
void add_edge(int u,int v){
    E[++cnt].nxt=head[u];
    E[cnt].to=v;
    head[u]=cnt;
}

int deep[100010],fath[100010];
int hson[100010],size[100010];
int top[100010],dfn[100010];
int pre[100010];

void dfs1(int u,int fa) {
    fath[u]=fa;
    deep[u]=deep[fa]+1;
    size[u]=1; 
    for(int i=head[u]; ~i; i=E[i].nxt) {
        int to=E[i].to;
        if(to==fa) continue;
        dfs1(to,u);
        size[u]+=size[to];

        if(!hson[u] || size[to]>size[hson[u]])
            hson[u]=to;
    }
}//tarjan跑一遍明确父子关系,重链轻链
int tot=0;
void dfs2(int u,int num) {
    top[u]=num;//记录这条链的链顶
    dfn[u]=++tot;//节点在树中的新编号 
    pre[tot]=w[u];//u点在原始的树种的编号pre[u]; 

    if(!hson[u]) return;
    dfs2(hson[u],num);

    for(int i=head[u]; ~i; i=E[i].nxt) {
        int to=E[i].to;
        if(to!=hson[u] && to!=fath[u]) dfs2(to,to);
    }
}
//以上为树剖部分

long long C[400040],b[400040];
void insert(int lr,int rr,int idx) {
    if(lr==rr){
        C[idx]=pre[lr];
        if(C[idx]>MOD) C[idx]=C[idx]%MOD;
        return ;
    }
    int mid=lr+(rr-lr>>1);
    insert(lr,mid,idx<<1);
    insert(mid+1,rr,(idx<<1)+1);
    C[idx]=(C[idx<<1]+C[(idx<<1)+1])%MOD;
}

long long Sum(int lr,int rr,int ST,int ED,int idx) {
    long long Ans=0;
    if(lr<=ST && ED<=rr){
        Ans+=C[idx];
        Ans=Ans%MOD;
        return Ans;
    } 
    int mid=ST+(ED-ST>>1);
    if(b[idx]){
        b[idx<<1]=b[idx]+b[idx<<1];
        b[(idx<<1)+1]=b[idx]+b[(idx<<1)+1];

        C[idx<<1]+=b[idx]*(mid-ST+1);
        C[(idx<<1)+1]+=b[idx]*(ED-mid);

        C[idx<<1]%=MOD;
        C[(idx<<1)+1]%=MOD;
        b[idx]=0;
    }

    if(lr<=mid) Ans+=Sum(lr,rr,ST,mid,idx<<1);
    if(rr>mid) Ans+=Sum(lr,rr,mid+1,ED,(idx<<1)+1);
    return Ans;
}

void upd(int lr, int rr, int ST, int ED, int idx,long long d){
    if(lr<=ST && ED<=rr){
        b[idx]+=d;
        C[idx]=(ED-ST+1)*d;
        return;
    }
    int mid=ST+((ED-ST)>>1);
    if(b[idx]){
        b[idx<<1]+=b[idx];
        b[(idx<<1)+1]+=b[idx];

        C[idx<<1]+=b[idx]*(mid-ST);
        C[(idx<<1)+1]+=b[idx]*(ED-mid+1);

        C[idx<<1]%=MOD;
        C[(idx<<1)+1]%=MOD;
        b[idx]=0;
    }

    if(lr<=mid) upd(lr,rr,ST,mid,idx<<1,d);
    if(rr>mid) upd(lr,rr,mid+1,ED,(idx<<1)+1,d);
    C[idx]=(C[idx<<1]+C[(idx<<1)+1])%MOD;
}
//以上为线段树部分

long long range(int x,int y){
    long long Ans=0;
    while(top[x]!=top[y]){
        if(deep[top[x]]<deep[top[y]]) swap(x,y);
        res=0;
        res=Sum(dfn[top[x]],dfn[x],1,n,1);
        Ans+=res;
        Ans=Ans%MOD;
        x=fath[top[x]];
    }
    if(deep[x]>deep[y]) swap(x,y);

    res=0;
    res=Sum(dfn[x],dfn[y],1,n,1);

    Ans+=res;
    return Ans%MOD;
}

void update_range(int x,int y,long long k){
    k%=MOD;
    while(top[x]!=top[y]){
        if(deep[top[x]]<deep[top[y]]) swap(x,y);
        upd(dfn[top[x]],dfn[x],1,n,1,k);
        x=fath[top[x]];
    }

    if(deep[x]>deep[y]) swap(x,y);
    upd(dfn[x],dfn[y],1,n,1,k);
}

long long Rtsum(int x){
    return Sum(dfn[x],dfn[x]+size[x]-1,1,n,1);
}

void update_son(int x,long long k){
    upd(dfn[x],dfn[x]+size[x]-1,1,n,1,k);
}

//以上为树上的链转化为线段树的部分

int main() {
    memset(head,-1,sizeof(head));
    scanf("%d%d%d%d",&n,&m,&rt,&MOD);
    for(int i=1; i<=n; i++)
        scanf("%d",&w[i]);

    for(int i=1; i<n; i++){
        int x,y;
        scanf("%d%d",&x,&y);
        add_edge(x,y);
        add_edge(y,x);
    }

    dfs1(rt,0);
    dfs2(rt,rt);
    insert(1,n,1);

    for(int i=1; i<=m; i++){
        int op;
        scanf("%d",&op);

        if(op==1) {
            int u,v;
            long long w;
            scanf("%d%d%lld",&u,&v,&w);
            update_range(u,v,w);
        }

        if(op==2) {
            int u,v;
            scanf("%d%d",&u,&v);
            printf("%lld\n",range(u,v));
        }

        if(op==3){
            int x;
            long long y;
            scanf("%d%lld",&x,&y);
            update_son(x,y);
        }

        if(op==4) {
            int x;
            scanf("%d",&x);
            printf("%lld\n",Rtsum(x));
        }
    }
    return 0;
}

by __Chx__ @ 2023-12-07 19:07:17

@Willan_Lian

发现了两处小问题:

C[idx]=(ED-ST+1)*d;
C[idx<<1]+=b[idx]*(mid-ST);
C[(idx<<1)+1]+=b[idx]*(ED-mid+1);

应改为:

C[idx<<1]+=b[idx]*(mid-ST+1);
C[(idx<<1)+1]+=b[idx]*(ED-mid);

最后在输出的时候记得取个模就能AC了


by Willan_Lian @ 2023-12-08 17:54:53

xx


|