样例过不了 0pts

P3384 【模板】重链剖分/树链剖分

Shadow_Lord @ 2023-07-07 10:44:01

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=1e5+5;
inline int read()
{
    int s=0,w=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
    while(ch>='0'&&ch<='9')s=(s<<3)+(s<<1)+(ch^48),ch=getchar();
    return s*w;
}
int n,m,root,mod,rt,v[N];
int head[N<<1],cnt,f[N],d[N],sz[N];
int top[N],rk[N],lazy[N<<1],son[N],id[N];
struct node{
    int to,net;
}e[N<<1];
struct Node{
    int l,r,ls,rs,sum;
}a[N<<1];
inline void add(int x,int y)
{
    e[++cnt].to=y;
    e[cnt].net=head[x];
    head[x]=cnt;
}
inline void dfs1(int fa,int x)
{
    f[x]=fa;d[x]=d[f[x]]+1;
    sz[x]=1;
    for(int i=head[x];i;i=e[i].net)
    {
        int to=e[i].to;
        if(to==fa)continue;
        dfs1(x,to);
        sz[x]+=sz[to];
        if(sz[to]>sz[son[x]])
        {
            son[x]=to;
        }
    }
}
inline void dfs2(int t,int x)
{
    top[x]=t;id[x]=++cnt;
    rk[cnt]=x;
    if(!son[x])return;
    dfs2(t,son[x]);
    for(int i=head[x];i;i=e[i].net)
    {
        int to=e[i].to;
        if(to==son[x]||to==f[x])continue;
        dfs2(to,to);
    }
}
inline int len(int x)
{
    return a[x].r-a[x].l+1;
}
inline void pushup(int x)
{
    a[x].sum=a[a[x].ls].sum+a[a[x].rs].sum;
    a[x].sum%=mod;
}
inline void pushdown(int x)
{
    if(lazy[x])
    {
        int ls=a[x].ls,rs=a[x].rs,s=lazy[x];
        lazy[ls]+=s;lazy[ls]%=mod;
        lazy[rs]+=s;lazy[rs]%=mod;
        a[ls].sum+=s*len(ls);
        a[ls].sum%=mod;
        a[rs].sum+=s*len(rs);
        a[rs].sum%=mod;
        lazy[x]=0;
    }
}
void build(int l,int r,int x)
{
    if(l==r)
    {
        a[x].sum=v[rk[l]],a[x].l=a[x].r=l;
        return ;
    }
    int mid=(l+r)>>1;
    a[x].ls=++cnt;
    a[x].rs=++cnt;
    build(l,mid,a[x].ls);build(mid+1,r,a[x].rs);
    a[x].l=l;a[x].r=r;
    pushup(x);
}
inline void update(int l,int r,int z,int x)
{
    if(a[x].l>=l&&a[x].r<=r)
    {
        (lazy[x]=lazy[x]+z)%=mod;
        (a[x].sum=a[x].sum+len(x)*z)%=mod;
        return ;
    }
    pushdown(x);
    int mid=(a[x].l+a[x].r)>>1;
    if(mid>=l)
    {
        update(l,r,z,a[x].ls);
    }
    if(mid+1<=r)
    {
        update(l,r,z,a[x].rs);
    }
    pushup(x);
}
inline int query(int l,int r,int x)
{
    if(a[x].l>=l&&a[x].r<=r)
    {
        return a[x].sum;
    }
    pushdown(x);
    int mid=(a[x].l+a[x].r)>>1,ans=0;
    if(mid>=l)
    {
        ans+=query(l,r,a[x].ls);
        ans%=mod;
    }
    if(mid+1<=r)
    {
        ans+=query(l,r,a[x].rs);
        ans%=mod;
    }
    return ans%mod;
}
inline int sum(int x,int y)
{
    int r=0;
    while(top[x]!=top[y])
    {
        if(d[top[x]]<d[top[y]])swap(x,y);
        (r+=query(id[top[x]],id[x],rt))%=mod;
        x=f[top[x]];
    }
    if(d[x]>d[y])swap(x,y);
    return (r+query(id[x],id[y],rt))%mod;
}
inline void update1(int x,int y,int z)
{
    while(top[x]!=top[y])
    {
        if(d[top[x]]<d[top[y]])swap(x,y);
        update(id[top[x]],id[x],z,rt);
        x=f[top[x]];
    }
    if(d[x]>d[y])
    {
        swap(x,y);
    }
    update(id[x],id[y],z,rt);
}
signed main()
{
    n=read();m=read();root=read();mod=read();
    for(int i=1;i<=n;i++)
    {
        v[i]=read();
    }
    for(int i=1;i<n;i++)
    {
        int x=read(),y=read();
        add(x,y);
        add(y,x);
    }
    cnt=0;
    dfs1(0,root);dfs2(root,root);
    cnt=0;
    build(1,n,rt=++cnt);
    while(m--)
    {
        int opt=read();
        // cout<<opt<<"\n";
        if(opt==1)
        {
            update1(read(),read(),read());
        }
        else if(opt==2)
        {
            cout<<sum(read(),read())<<"\n";
        }
        else if(opt==3)
        {
            int x=read(),k=read();
            update(id[x],id[x]+sz[x]-1,k,rt);
        }
        else
        {
            int x=read();
            cout<<query(id[x],id[x]+sz[x]-1,rt)<<"\n";
        }
    }
    return 0;
}

by Shadow_Lord @ 2023-07-07 10:45:01

样例输出 2和 4


by _XHY20180718_ @ 2023-07-10 15:00:41

@Shadow_Lord 线段树挂了吧:

if(mid>=l)
    {
        update(l,r,z,a[x].ls);
    }
    if(mid+1<=r)
    {
        update(l,r,z,a[x].rs);
    }

应该是:

if(mid>=l)
    {
        update(l,mid,z,a[x].ls);
    }
    if(mid+1<=r)
    {
        update(mid+1,r,z,a[x].rs);
    }

by _XHY20180718_ @ 2023-07-10 15:02:59

qry也一样


|