求助

P3384 【模板】重链剖分/树链剖分

_Dolphin_ @ 2023-05-06 16:12:03

#include<bits/stdc++.h>
// DEBUG is defined here but not referenced anywhere in the visible source.
#define DEBUG
// Segment-tree helper macros. They expand textually, so they only work inside
// functions whose parameters are literally named l, r and pos.
// mid  : midpoint of the current interval [l, r].
#define mid ((l+r)>>1)
// lson / rson : indices of the left (2*pos) and right (2*pos+1) child nodes.
#define lson pos<<1
#define rson pos<<1|1
// afor: ascending loop x = y .. z (inclusive); sfor: descending loop x = y .. z (inclusive).
#define afor(x,y,z) for(int x=y;x<=z;x++)
#define sfor(x,y,z) for(int x=y;x>=z;x--)
using namespace std;
// Short type aliases.
typedef double dou;
typedef long long ll;
typedef const int cint;
typedef unsigned int uint;
typedef const long long cll;
typedef unsigned long long ull;
// Array capacity: 1e5 plus a little slack (indices are 1-based throughout).
cint N=1e5+10;
// n: node count, m: number of operations, r: root node, mod: output modulus,
// a[i]: initial value of node i (all read in main).
int n,m,r,mod,a[N];
// Segment tree over dfs2 positions 1..n: d[] holds interval sums,
// b[] holds pending "add to every element" tags (see pushdown).
int d[N<<2],b[N<<2];
// Linked adjacency lists: head[x] is the newest edge id of x, to[e] its
// endpoint, nxt[e] the next edge id (0 ends the list); tot counts edge ids.
int tot,head[N],to[N<<1],nxt[N<<1];
// Decomposition data. Filled by dfs1: fa (parent), son (child with the largest
// subtree), dep (depth), siz (subtree size). Filled by dfs2: p (position of a
// node, numbered by cnt), flag (node stored at a position), bel (chain head).
int cnt,fa[N],son[N],dep[N],siz[N],p[N],flag[N],bel[N];
// Insert the directed edge x -> y at the front of x's adjacency list.
void add(int x,int y) {
    const int e=++tot;   // edge ids start at 1; id 0 is the list terminator
    to[e]=y;
    nxt[e]=head[x];      // chain the previous first edge behind the new one
    head[x]=e;
}
// First pass of the decomposition, visiting x whose parent is y.
// Records fa[x], dep[x] (= dep[y] + 1), siz[x] (subtree size) and son[x],
// the child with the strictly largest subtree seen so far (0 if x has no child).
void dfs1(int x,int y) {
    fa[x]=y;
    dep[x]=dep[y]+1;
    siz[x]=1;
    for(int e=head[x];e!=0;e=nxt[e]) {
        const int child=to[e];
        if(child==y) continue;          // do not walk back up to the parent
        dfs1(child,x);
        siz[x]+=siz[child];
        // son[x] starts at 0 and siz[0] is 0, so the first child always wins.
        if(siz[son[x]]<siz[child]) son[x]=child;
    }
}
// Second pass of the decomposition: x is the node being numbered, y is the
// head of the chain x belongs to. Assigns p[x], flag[p[x]] and bel[x].
// son[x] is visited first, so a chain occupies consecutive positions.
void dfs2(int x,int y) {
    const int idx=++cnt;
    p[x]=idx;
    flag[idx]=x;
    bel[x]=y;
    const int heavy=son[x];
    if(heavy==0) return;                // no child recorded by dfs1: nothing below x
    dfs2(heavy,y);                      // continue the current chain
    for(int e=head[x];e!=0;e=nxt[e]) {
        const int child=to[e];
        if(child==heavy||child==fa[x]) continue;
        dfs2(child,child);              // every other child starts its own chain
    }
}
void build(int l,int r,int pos) {
    if(l==r) {
        d[pos]=a[flag[l]];
        return;
    }
    build(l,mid,lson);
    build(mid+1,r,rson);
    d[pos]=d[lson]+d[rson];
}
void pushdown(int l,int r,int pos) {
    d[lson]+=b[pos]*(mid-l+1);
    d[rson]+=b[pos]*(r-mid);
    b[lson]+=b[pos];
    b[rson]+=b[pos];
    b[pos]=0;
}
void update(int l,int r,int pos,int c,int h,int t) {
    if(h<=l&&r<=t) {
        d[pos]+=c*(r-l+1);
        b[pos]+=c;
        return;
    }
    if(b[pos]) pushdown(l,r,pos);
    if(h<=mid) update(l,mid,lson,c,h,t);
    if(t>mid) update(mid+1,r,rson,c,h,t);
    d[pos]=d[lson]+d[rson];
}
int querysum(int l,int r,int pos,int h,int t) {
    if(h<=l&&r<=t) return d[pos];
    if(b[pos]) pushdown(l,r,pos);
    int sum=0;
    if(h<=mid) sum+=querysum(l,mid,lson,h,t);
    if(t>mid) sum+=querysum(mid+1,r,rson,h,t);
    return sum;
}
// Add z to every node on the tree path between x and y.
// While the two endpoints lie on different chains, the endpoint whose chain
// head is deeper is processed: its chain segment [p[bel[x]], p[x]] is updated
// and the endpoint jumps to the parent of that chain head. Once both are on
// the same chain, the positions between them are updated.
void fix(int x,int y,int z) {
    for(;bel[x]!=bel[y];x=fa[bel[x]]) {
        if(dep[bel[x]]<dep[bel[y]]) swap(x,y);   // keep x on the deeper chain
        update(1,n,1,z,p[bel[x]],p[x]);
    }
    if(dep[x]<dep[y]) swap(x,y);                 // y is now the shallower endpoint
    update(1,n,1,z,p[y],p[x]);
}
int getsum(int x,int y) {
    int sum=0;
    while(bel[x]!=bel[y]) {
        if(dep[bel[x]]<dep[bel[y]]) swap(x,y);
        sum+=querysum(1,n,1,p[bel[x]],p[x]);
        x=fa[bel[x]];
    }
    if(dep[x]<dep[y]) swap(x,y);
    sum+=querysum(1,n,1,p[y],p[x]);
    return sum;
}
int main() {
    int opt,x,y,z;
    scanf("%d%d%d%d",&n,&m,&r,&mod);
    afor(i,1,n) scanf("%d",&a[i]);
    dfs1(r,r);
    dfs1(r,0);
    build(1,n,1);
    afor(i,1,n-1) {
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    afor(i,1,m) {
        scanf("%d",&opt);
        if(opt==1) {
            scanf("%d%d%d",&x,&y,&z);
            fix(x,y,z);
        }
        else if(opt==2) {
            scanf("%d%d",&x,&y);
            printf("%d\n",getsum(x,y)%mod);
        }
        else if(opt==3) {
            scanf("%d%d",&x,&z);
            update(1,n,1,z,p[x],p[x]+siz[x]-1);
        }
        else {
            scanf("%d",&x);
            printf("%d\n",querysum(1,n,1,p[x],p[x]+siz[x]-1)%mod);
        }
    }
    return 0;
}

为什么炸了?


by LgxTpre @ 2023-05-06 16:31:27

@Dolphin0613 你为什么先剖分、建线段树,然后才加边?另外为什么调用了两遍 dfs1,却没有调用 dfs2?


by _Dolphin_ @ 2023-05-07 13:30:01

@LgxTpre 谢谢,当时人傻了


|