树剖T了3 个点 求调

P3384 【模板】重链剖分/树链剖分

syLph @ 2023-12-30 11:32:04

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstring>
#include<memory.h>
#include<vector>

using namespace std;

const int maxn = 2000010;

typedef long long ll;

ll nxt[maxn],head[maxn],ver[maxn],tot = 1;
ll dep[maxn],fa[maxn],siz[maxn],son[maxn];
ll top[maxn],rnk[maxn],cnt,dfn[maxn],a[maxn],w[maxn];
ll lazy[maxn],dat[maxn];
ll n,m,root,mod;

void pushdown(ll p,ll l,ll r){
    if(lazy[p]){
        ll mid = (l + r) >> 1;
        lazy[p*2] += lazy[p]; dat[p*2] += lazy[p] * (mid - l + 1);
        lazy[p*2+1] += lazy[p]; dat[p*2+1] += lazy[p] * (r - mid);
        lazy[p] = 0;
        lazy[p*2] = lazy[p*2] % mod;
        lazy[p*2+1] = lazy[p*2+1] % mod;
    }
}
void build(ll p,ll l,ll r){
    if(l == r){
        dat[p] = w[l];
        return;
    }
    ll mid = (l + r) >> 1;
    if(l<=mid) build(p*2,l,mid);
    if(mid+1<=r) build(p*2+1,mid+1,r);
    dat[p] = dat[p*2] + dat[p*2+1];
    dat[p] = dat[p] % mod;
}

void update(ll p,ll l,ll r,ll L,ll R,ll k){
    if(l>R||r<L) return;
    if(L<=l&&r<=R){
        dat[p] += (r - l + 1) * k; lazy[p] += k;
        dat[p] = dat[p] % mod;
        pushdown(p,l,r);
        return;
    }
    pushdown(p,l,r);
    ll mid = (l + r) >> 1;
    if(l<=mid) update(p*2,l,mid,L,R,k);
    if(mid+1<=r) update(p*2+1,mid+1,r,L,R,k);
    dat[p] = dat[p*2] + dat[p*2+1];
    dat[p] = dat[p] % mod;
}

ll getsum(ll p,ll l,ll r,ll L,ll R){
    if(l>R||r<L) return 0;
    if(L<=l&&r<=R) {
        return dat[p];
    }
    pushdown(p,l,r);
    ll mid = (l + r) >> 1,ret = 0;
    if(l<=mid) ret += getsum(p*2,l,mid,L,R);
    if(mid+1<=r) ret += getsum(p*2+1,mid+1,r,L,R);
    return ret % mod;
}

void add(ll x,ll y){
    ver[++tot] = y;
    nxt[tot] = head[x];
    head[x] = tot;
}

void dfs1(ll x){
    son[x] = -1,siz[x] = 1;
    for(ll i = head[x] ; i ; i = nxt[i]){
        ll y = ver[i];
        if(!dep[y]&&y!=fa[x]){
            dep[y] = dep[x] + 1;
            fa[y] = x;
            dfs1(y); 
            siz[x] += siz[y];
            if(son[x]==-1&&siz[ver[i]]>siz[son[x]]) son[x] = ver[i];
        }
    }
}
void dfs2(ll x,ll t){
    top[x] = t; dfn[x] = ++cnt; w[cnt] = a[x];
    if(son[x] == -1) return;
    dfs2(son[x],t);
    for(ll i = head[x] ; i ; i = nxt[i])
        if(ver[i]!=son[x]&&ver[i]!=fa[x]) 
            dfs2(ver[i],ver[i]);
}
ll lca(ll x,ll y){
    while(top[x] != top[y]){
        if(dep[top[x]] > dep[top[y]]) x = fa[top[x]];
        else y = fa[top[y]];
    }
    return dep[x] > dep[y] ? y : x;
}
ll query1(ll x,ll y){
    ll ans = 0,res;
    while(top[x] != top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        ans += getsum(1,1,n,dfn[top[x]],dfn[x]);
        ans = (ans + mod) % mod;
        x = fa[top[x]];
    }
    if(dep[x] > dep[y]) swap(x,y);
    ans = (ans + getsum(1,1,n,dfn[x],dfn[y]) + mod) % mod;
    return (ans + mod) % mod;
}

void query2(ll x,ll y,ll k){
    k = (k + mod) % mod;
    while(top[x] != top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        update(1,1,n,dfn[top[x]],dfn[x],k);
        x = fa[top[x]];
    }
    if(dep[x] > dep[y]) swap(x,y);
    update(1,1,n,dfn[x],dfn[y],k);
}
ll query3(ll x){
    ll ans = getsum(1,1,n,dfn[x],dfn[x]+siz[x]-1);
    return (ans + mod) % mod;
}
void query4(ll x,ll k){
    update(1,1,n,dfn[x],dfn[x]+siz[x]-1,k);
    return;
}
int main(){
    //freopen("testdata.in","r",stdin);
    //freopen("testdata.out","w",stdout);
    scanf("%lld%lld%lld%lld",&n,&m,&root,&mod);
    for(ll i = 1 ; i <= n ; i ++) scanf("%lld",&a[i]);
    for(ll i = 1 ; i <= n - 1 ; i ++){
        ll u,v; scanf("%lld%lld",&u,&v); 
        add(u,v),add(v,u);
    }
    dep[root] = 1;
    dfs1(root); 
    dfs2(root,root);
    build(1,1,n);
    for(ll i = 0 ; i < m ; i ++){
        ll opt; scanf("%lld",&opt);
        if(opt == 1){
            ll x,y,z; scanf("%lld%lld%lld",&x,&y,&z);
            query2(x,y,z);
        }else if(opt == 2){
            ll x,y; scanf("%lld%lld",&x,&y);
            printf("%lld\n",query1(x,y));
        }else if(opt == 3){
            ll x,y; scanf("%lld%lld",&x,&y);
            query4(x,y);
        }else{
            ll x; scanf("%lld",&x);
            printf("%lld\n",query3(x));
        }
    }
    return 0;
}

不知道为什么…… 8,9,10 T 了


by _zzzzzzy_ @ 2023-12-30 11:33:10

@syLph 别都用ll试试,空间也别开那莫大


by syLph @ 2023-12-30 12:11:39

@zzzzzzy 似乎还是不行


by _zzzzzzy_ @ 2023-12-30 12:15:08

@syLph 根据你写的线段数函数有很多错误的,但是运行出来是对的,我估计是没有学得很好,然后错误大概率在线段树上


|