样例没过,求调QAQ(可能是线段树)

P3384 【模板】重链剖分/树链剖分

_int128 @ 2024-07-27 12:02:53

线段树不太熟练,但是这样写看不出错

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1e5+100,M=1e5+100;
ll cnt=0,n,m,r,p,scnt=0,wi[N];
ll w[N],head[N],fir[N];
ll depth[N],size[N],son[N],fa[N],id[N],top[N];
struct Edge{
    ll to,ne;
}e[2*M];
struct Tree{
    ll l,r,tag,dat;
}st[4*N];
void Add(int a,int b){
    e[++cnt].ne=fir[a];
    e[cnt].to=b;
    fir[a]=cnt;
}
void add(int node,int l,int r,int y){
    if(l<=st[node].l&&r>=st[node].r){
        st[node].tag+=y;
        return ;
    }
    st[node].dat+=(y*(r-l+1))%p;
    st[node].dat%=p; 
    int mid=(st[node].l+st[node].r)/2;
    if(l<=mid){
        add(node*2,l,mid,y);
    }if(r>mid){
        add(node*2+1,mid+1,r,y);
    }
}
void push_down(int node){
    if(st[node].tag==0)return;
    st[node].dat+=(st[node].tag*((st[node].r-st[node].l+1)%p))%p;
    st[node].dat%=p;
    st[node*2].tag+=st[node].tag;
    st[node*2].tag%=p;
    st[node*2+1].tag+=st[node].tag;
    st[node*2+1].tag%=p;
    return;
}
ll sum(int node,int l,int r){
    push_down(node);
    ll u=0,v=0;
    if(l<=st[node].l&&r>=st[node].r){
        return st[node].dat;
    }
    int mid=(st[node].l+st[node].r)/2;
    if(l<=mid){
        u=sum(node*2,l,mid);
    }if(r>mid){
        v=sum(node*2+1,mid+1,r);
    }
    return u+v;
}
void build(int node,int l,int r){
    st[node].l=l;st[node].r=r;
    st[node].tag=0;
    if(l==r){
        st[node].dat=wi[r];
        st[node].dat%=p; 
        return;
    }
    int mid=(l+r)/2;
    build(2*node,l,mid);
    build(2*node+1,mid+1,r);
    st[node].dat=st[node*2+1].dat+st[node*2].dat;
    st[node].dat%=p;
}

void dfs1(int u,int f){
    fa[u]=f;
    depth[u]=depth[f]+1;
    size[u]=1;
    int t,v;
    for(int i=fir[u];i;i=e[i].ne){
        v=e[i].to;
        if(v==f) continue;
        dfs1(v,u);
        size[u]+=size[v];
        if(size[v]>t){
            t=size[v];
            son[u]=v;
        }
    }
}
void dfs2(int u,int f){
    top[u]=f;
    id[u]=++scnt;
    wi[scnt]=w[u];
    if(!son[u])return ;
    dfs2(son[u],f);
    for(int i=fir[u];i;i=e[i].ne){
        int v=e[i].to;
        if(v==f||v==son[u]) continue;
        dfs2(v,v);
    }
}
ll query_path(int u,int v){
    ll o=0;
    while(top[u]!=top[v]){
        if(depth[top[u]]<depth[top[v]]) swap(u,v);
        o+=sum(1,id[top[u]],id[u])%p;
        o%=p;
        u=fa[top[u]];
    }
    if(depth[top[u]]>depth[top[v]]) swap(u,v);
    o+=sum(1,id[u],id[v]);
    o%=p;
    return o;
}
int add_path(int u,int v,int k){
    while(top[u]!=top[v]){
        if(depth[top[u]]<depth[top[v]]) swap(u,v);
        add(1,id[top[u]],id[u],k);
        u=fa[top[u]];
    }
    if(depth[u]>depth[v]) swap(u,v);
    add(1,id[u],id[v],k);
}
int query_son(int u){
    return sum(1,id[u],id[u]+size[u]-1);
}
void add_son(int u,int k){
    add(1,id[u],id[u]+size[u]-1,k);
}
int main(){
    cin>>n>>m>>r>>p;
    for(int i=1;i<=n;i++){
        cin>>w[i];
    }
    int u,v;
    for(int i=1;i<n;i++){
        cin>>u>>v;
        Add(u,v);
        Add(v,u);
    }
    dfs1(r,-1);
    dfs2(r,r);
    build(1,1,n);
    while(m--){
        int k,x,y,z;
        cin>>k;
        if(k==1){
            cin>>x>>y>>z;
            add_path(x,y,z);
        }else if(k==2){
            cin>>x>>y;
            cout<<query_path(x,y)<<endl;
        }else if(k==3){
            cin>>x>>y;
            add_son(x,y);
        }else{
            cin>>x;
            cout<<query_son(x)<<endl;
        }
    }
    return 0;
}

by Chase12345 @ 2025-01-01 19:14:18

区间加没向下传递@_int128


|