重链剖分模板求助,悬关

P3384 【模板】重链剖分/树链剖分

017_007 @ 2023-08-25 13:42:29

19 pts,蒟蒻没看出来哪里出问题。

#include<bits/stdc++.h>
#define int long long
using namespace std;
inline int read() {
    int x=0,f=1;char s=getchar();
    while (s>'9'||s<'0') {
        if (s=='-') f=-f;
        s=getchar();
    }
    while (s>='0'&&s<='9') {
        x=(x<<1)+(x<<3)+s-'0';
        s=getchar();
    }
    return x*f;
}
const int N = 1e5+10;
int n,m,s,mod,a[N],u,v,first[N],cnt,op,x,y,z;
int dep[N],fa[N],zson[N],num[N],st[N],nid[N],nw[N],dnow;
int sum[N*4],add[N*4];
struct edge{
    int to,nxt;
}edges[N*2];
void Add(int u,int v) {
    edges[++cnt].to=v;
    edges[cnt].nxt=first[u];
    first[u]=cnt;   
}
void dfs1(int root,int last) {
    fa[root]=last;dep[root]=dep[last]+1;
    num[root]=1;
    nid[root]=++dnow;nw[dnow]=a[root];
    for (int t=first[root];t;t=edges[t].nxt) {
        int h=edges[t].to;
        if (h==last) continue;
        dfs1(h,root);
        num[root]+=num[h];
        if (num[h]>num[zson[root]]) zson[root]=h;
    }
}
void dfs2(int root,int topx) {
    st[root]=topx;
    if (!zson[root]) return;
    else dfs2(zson[root],topx);
    for (int t=first[root];t;t=edges[t].nxt) {
        int h=edges[t].to;
        if (h==zson[root]||h==fa[root]) continue;
        dfs2(h,h);
    }
}
void pushup(int k) { sum[k]=(sum[k<<1]+sum[k<<1|1])%mod;}
void build(int k,int l,int r) {
    if (l==r) {
        sum[k]=nw[l]%mod;
        return;
    }
    int mid=(l+r)>>1;
    build(k<<1,l,mid);
    build(k<<1|1,mid+1,r);
    pushup(k);
}
void pushdown(int k,int l,int r) {
    if (!add[k]) return;
    int mid=(l+r)>>1;
    sum[k<<1]=(sum[k<<1]+(mid-l+1)*add[k]%mod)%mod;
    sum[k<<1|1]=(sum[k<<1|1]+(r-mid)*add[k]%mod)%mod;
    add[k<<1]=(add[k<<1]+add[k])%mod;
    add[k<<1|1]=(add[k<<1|1]+add[k])%mod;
    add[k]=0;
}
void modify(int k,int l,int r,int zl,int zr,int v) {
    if (zl<=l&&zr>=r) {
        sum[k]=(sum[k]+(r-l+1)*v%mod)%mod;
        add[k]=(add[k]+v)%mod;
        return;
    } 
    pushdown(k,l,r);
    int mid=(l+r)>>1;
    if (mid>=zl) modify(k<<1,l,mid,zl,zr,v);
    if (mid<zr) modify(k<<1|1,mid+1,r,zl,zr,v);
    pushup(k); 
}
int ser(int k,int l,int r,int zl,int zr) {
    if (zl<=l&&zr>=r) return sum[k];
    pushdown(k,l,r);
    int mid=(l+r)>>1,res=0;
    if (mid>=zl) res=(res+ser(k<<1,l,mid,zl,zr))%mod;
    if (mid<zr) res=(res+ser(k<<1|1,mid+1,r,zl,zr))%mod;
    return res%mod;
}
void add_path(int x,int y,int z) { 
    int d1=x,d2=y;
    while (st[d1]!=st[d2]) {
        if (dep[st[d1]]<dep[st[d2]]) swap(d1,d2);
        modify(1,1,n,nid[st[d1]],nid[d1],z);
        d1=fa[st[d1]];
    }
    int l=min(nid[d1],nid[d2]),r=max(nid[d1],nid[d2]);
    modify(1,1,n,l,r,z);
}
int ser_path(int x,int y) {
    int ans=0,d1=x,d2=y;
    while (st[d1]!=st[d2]) {
        if (dep[st[d1]]<dep[st[d2]]) swap(d1,d2);
        ans=(ans+ser(1,1,n,nid[st[d1]],nid[d1]))%mod;
        d1=fa[st[d1]];
    }
    int l=min(nid[d1],nid[d2]),r=max(nid[d1],nid[d2]);
    ans=(ans+ser(1,1,n,l,r))%mod;
    return ans;
}
void add_tree(int root,int v) {
    int l=nid[root],r=nid[root]+num[root]-1;
    modify(1,1,n,l,r,v);
}
int ser_tree(int root) {
    int l=nid[root],r=nid[root]+num[root]-1;
    return ser(1,1,n,l,r);
}
signed main() {
    n=read();m=read();s=read();mod=read();
    for (int i=1;i<=n;++i) a[i]=read();
    for (int i=1;i<n;++i) u=read(),v=read(),Add(u,v),Add(v,u);
    dfs1(s,0);dfs2(s,s);
    build(1,1,n);
    while (m--) {
        op=read();
        if (op==1) {
            x=read();y=read();z=read();
            add_path(x,y,z);
        }
        if (op==2) {
            x=read();y=read();
            printf("%lld\n",ser_path(x,y)%mod);
        }
        if (op==3) {
            x=read();z=read();
            add_tree(x,z);
        }
        if (op==4) {
            x=read();
            printf("%lld\n",ser_tree(x)%mod);
        }
    }
    return 0;
}

by Iniaugoty @ 2023-08-25 13:46:30

@017_007 dfs序要在第二遍dfs时计算,你的代码里应该是nid数组


by 017_007 @ 2023-08-25 13:54:24

@gty314159 谢谢大佬,已关注


by 017_007 @ 2023-08-25 13:55:02

此贴结。(这是我第一次那么快结贴子


|