树链剖分样例未过求调(含变量名解释)

P3384 【模板】重链剖分/树链剖分

Point_LUO @ 2024-08-08 23:17:21

thx&rp++

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10;
int n,m,rt,mod,tmp[N],w[N];//tmp:按点原始序号记录的值 w:按dfn序(编号)记录的值 
int h[N],to[N<<1],nxt[N<<1],cnt;//链式前向星 
int tre[N<<2],tag[N<<2];//线段树 
int dep[N],f[N],s[N],siz[N];//点深度;父亲节点;重儿子节点;子树大小 
int d[N],dfn;//编号 
int st[N];//链起点 
void addedge(int u,int v)
{
    ++cnt;
    nxt[cnt]=h[u];
    h[u]=cnt;
    to[cnt]=v;
}
void init(int u,int fa)
{
    dep[u]=dep[fa]+1;
    f[u]=fa;
    siz[u]=1;
    int tmp=0;
    for(int i=h[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==fa) continue;
        init(v,u);
        siz[u]+=siz[v];
        if(tmp<siz[v])
        {
            tmp=siz[v];
            s[u]=v;
        }
    }
}
void dfs(int u,int t)
{
    d[u]=++dfn;
    w[dfn]=tmp[u];
    st[u]=t; 
    if(!s[u]) return;
    dfs(s[u],t);
    for(int i=h[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==f[u]||v==s[u]) continue;
        dfs(v,v);
    }
}
int ls(int x){ return x<<1; }
int rs(int x){ return x<<1|1; }
void pushup(int p){ tre[p]=(tre[ls(p)]+tre[rs(p)])%mod; }
void buildtree(int p,int l,int r)
{
    if(l==r)
    {
        tre[p]=w[l]%mod;
        return;
    }
    int m=(l+r)>>1;
    buildtree(ls(p),l,m);
    buildtree(rs(p),m+1,r);
    pushup(p);
}
void pushdown(int p,int l,int r)
{
    int m=(l+r)>>1;
    tre[ls(p)]=(tre[ls(p)]+tag[p]*(m-l+1))%mod;
    tre[rs(p)]=(tre[rs(p)]+tag[p]*(r-m))%mod;
    tag[ls(p)]+=tag[p]%mod;
    tag[rs(p)]+=tag[p]%mod;
    tag[p]=0;
}
void update(int p,int l,int r,int x,int y,int v)
{
    if(l>=x&&r<=y)
    {
        tre[p]+=(r-l+1)*v;
        tre[p]%=mod;
        tag[p]+=v;
        return;
    }
    pushdown(p,l,r);
    int ans=0;
    int m=(l+r)>>1;
    if(x<=m) update(ls(p),l,m,x,y,v);
    if(y>m) update(rs(p),m+1,r,x,y,v);
    pushup(p);
}
int query(int p,int l,int r,int x,int y)
{
    if(l>=x&&r<=y) return tre[p]%mod;
    pushdown(p,l,r);
    int ans=0;
    int m=(l+r)>>1;
    if(x<=m) ans=(ans+query(ls(p),l,m,x,y))%mod;
    if(y>m) ans=(ans+query(rs(p),m+1,r,x,y))%mod;
    return ans;
}
int qRange(int x,int y)
{
    int ans=0;
    while(st[x]!=st[y])
    {
        if(dep[st[x]]<dep[st[y]]) swap(x,y);
        ans+=query(1,1,n,d[st[x]],d[x]);
        ans%=mod;
        x=f[st[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    ans+=query(1,1,n,d[x],d[y]);
    ans%=mod;
    return ans;
}
void updRange(int x,int y,int k)
{
    k%=mod;
    while(st[x]!=st[y])
    {
        if(dep[st[x]]<dep[st[x]]) swap(x,y);
        update(1,1,n,d[st[x]],d[x],k);
        x=f[st[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    update(1,1,n,d[x],d[y],k);
}
int qSon(int x){ return query(1,1,n,d[x],d[x]+siz[x]-1); }
void updSon(int x,int k){ update(1,1,n,d[x],d[x]+siz[x]-1,k); }
int main()
{
    cin>>n>>m>>rt>>mod;
    for(int i=1;i<=n;i++) scanf("%d",&tmp[i]);
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        addedge(x,y);
        addedge(y,x);
    }
    init(rt,0);
    dfs(rt,rt);
    buildtree(rt,1,n);
    for(int i=1;i<=m;i++)
    {
        int opt;
        scanf("%d",&opt);
        if(opt==1)
        {
            int x,y,z;
            scanf("%d%d%d",&x,&y,&z);
            updRange(x,y,z);
        }
        else if(opt==2)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            printf("%d\n",qRange(x,y));
        }
        else if(opt==3)
        {
            int x,z;
            scanf("%d%d",&x,&z);
            updSon(x,z);
        }
        else
        {
            int x;
            scanf("%d",&x);
            printf("%d\n",qSon(x));
        }
    }
    return 0;
}

by Point_LUO @ 2024-08-09 14:44:15

已过,拜谢 @Toclhu 大佬,此贴结 /bx


|