求各位大佬帮我调一下,树链剖分模板样例就是过不了,奔溃了

P3384 【模板】重链剖分/树链剖分

yuyudong @ 2024-07-22 17:53:55

#include<bits/stdc++.h>
using namespace std;
const int maxn=10000005;
int n,m,root,mod,siz[maxn],f[maxn],top[maxn],id[maxn],wt[maxn],tag[maxn];
int a[maxn],sum[maxn],vis[maxn],son[maxn],dep[maxn];
int e;
int head[maxn],edge[maxn],nxt[maxn],cnt=0,R[maxn],L[maxn];
void add(int x,int y)
{
    edge[++cnt]=y;nxt[cnt]=head[x];head[x]=cnt;
}
void build(int p,int l,int r)
{
    if(l==r) 
    {
        sum[p]=wt[l];
        return;
    }
    int mid=l+r>>1;
    build(p*2,l,mid);
    build(p*2+1,mid+1,r);
    sum[p]=sum[p*2]+sum[p*2+1];
}
void maketag(int p,int z,int l,int r)
{
    tag[p]+=z;
    sum[p]+=(r-l+1)*z; 
}
void pushdown(int p,int l,int r)
{
    if(tag[p])
    {
        maketag(p*2,tag[p],l,r);
        maketag(p*2+1,tag[p],l,r);
        tag[p]=0;
    }
}
int query(int p,int l,int r,int x,int y)
{
    if(x<=l&&r<=y)
    {
        return sum[p]; 
    }
    pushdown(p,l,r);
    int mid=l+r>>1;
    int res=0;
    if(x<=mid) res+=query(p*2,l,mid,x,y);
    if(mid<y) res+=query(p*2+1,mid+1,r,x,y);
    return res;
}
void update(int p,int l,int r,int x,int y,int z)
{
    if(x<=l&&r<=y)
    {
        sum[p]+=(r-l+1)*z;
        tag[p]+=z;
        return;
    }
    pushdown(p,l,r);
    int mid=l+r>>1;
    if(x<=mid) update(p*2,l,mid,x,y,z);
    if(mid<y) update(p*2+1,mid+1,y,x,y,z);
    sum[p]=sum[p*2]+sum[p*2+1];
}
//线段树 
void dfs1(int u,int fa)
{
    siz[u]=1;
    f[u]=fa;
    dep[u]=dep[fa]+1;
    int maxson=-1;
    for(int i=head[u];i;i=nxt[i])
    {
        int v=edge[i];
        if(v==fa) continue;
        dfs1(v,u);
        siz[u]+=siz[v];
        if(siz[v]>maxson)
        {
            maxson=siz[v];
            son[u]=v;
        }
    }
}
void dfs2(int u,int topp)
{
    id[u]=++e;
    top[u]=topp;
    wt[e]=a[u];
    if(son[u]==0) return;
    dfs2(son[u],topp);
    for(int i=head[u];i;i=nxt[i])
    {
        int v=edge[i];
        if(v!=son[u]&&v!=f[u])
        {
            dfs2(v,v);
        }
    }
}
int  qroad(int x,int y)
{
    int ans=0;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]]){
            swap(x,y);
        }
        ans+=query(1,1,n,id[top[x]],id[x]);ans%=mod;
        x=f[top[x]];
    }
    if(dep[x]>dep[y]){
        swap(x,y);
    }
    ans+=query(1,1,n,id[x],id[y]);ans%=mod;
    return ans;
}
void updroad(int x,int y,int z)
{
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]])
        {
            swap(x,y);
        }
        update(1,1,n,id[top[x]],id[x],z);
        x=f[top[x]];
    }
    if(dep[x]>dep[y])
    {
        swap(x,y); 
     } 
    update(1,1,n,id[x],id[y],z);
}
void updtree(int u,int z)
{
    update(1,1,n,id[u],id[u]+siz[u]-1,z);
}
int qtree(int u)
{
    int res=query(1,1,n,id[u],id[u]+siz[u]-1);
    res%=mod;
    return res;
}
int main()
{
    cin>>n>>m>>root>>mod;
    for(int i=1;i<=n;i++){
        cin>>a[i];
    }
    for(int i=1;i<n;i++)
    {
        int x,y;
        cin>>x>>y;
        add(x,y);
        add(y,x);
    }
    dfs1(root,0);
    dfs2(root,root);    
    build(1,1,n);
    for(int i=1;i<=m;i++ )
    {
        int opt;
        cin>>opt;
        int x,y,z;
        if(opt==1)
        {
            cin>>x>>y>>z;
            updroad(x,y,z);
        }
        else if(opt==2)
        {
            cin>>x>>y;
            cout<<qroad(x,y)<<endl;
        }
        else if(opt==3)
        {
            cin>>x>>z;
            updtree(x,z);
        }
        else 
        {
            cin>>x;
            cout<<qtree(x)<<endl;
        }
    }
    return 0;
}

by Jonny09 @ 2024-07-22 19:40:21

1.

if(mid<y) update(p*2+1,mid+1,y,x,y,z);

改成

if(mid<y) update(p*2+1,mid+1,r,x,y,z);

2.你的pushdown函数写错了,maketag的后两个参数分别是l,mid和mid+1,r,不是l和r


by yuyudong @ 2024-07-24 09:47:30

谢谢大佬


|