树剖求调

P3384 【模板】重链剖分/树链剖分

Henly_Z @ 2024-07-29 16:01:35

WA3,AC2,RE6

#include<bits/stdc++.h>
#define int long long
using namespace std;
// Capacity of every global array below.
const int maxn = 200005;
// Undirected adjacency lists; main() pushes both directions of each input edge.
vector<int> g[maxn];
// Not referenced anywhere else in the visible code.
vector<int> num[maxn];
// dep[u]: depth, dfs1 sets it to dep[parent] + 1 (dep[0] stays 0).
// sz[u]:  subtree size computed by dfs1.
// n, q:   first two numbers read in main(); n bounds the vertex loops and
//         q counts the operations processed.
// fa[u]:  parent of u as recorded by dfs1 (0 for the root).
// dis[u]: dis[parent] + a[u], written by dfs1 and never read afterwards.
// out[]:  never assigned in the visible code.
int dep[maxn] , sz[maxn],n,q,fa[maxn],dis[maxn],out[maxn];
// a[i]: value of vertex i as read from input; r: vertex both DFS passes start
// from; mod: modulus applied to every sum.
int a[maxn] , r,mod;
// nn:      number of DFS positions handed out so far by dfs2.
// rt[u]:   top vertex of the chain that dfs2 placed u on; 0 means "not visited yet".
// in[u]:   DFS position assigned to u by dfs2 (1-based).
// path[p]: vertex sitting at DFS position p (inverse of in[]).
int nn , rt[maxn] , in[maxn] , path[maxn];
// First DFS pass: records parent, depth, subtree size and the running sum of
// a[] from the start vertex down to u.
//   u - vertex being visited
//   f - vertex we arrived from (0 for the start vertex)
void dfs1(int u , int f) {
    dis[u] = a[u] + dis[f];
    dep[u] = 1 + dep[f];
    fa[u] = f;
    sz[u] = 1;
    for (auto child : g[u]) {
        // The adjacency list is undirected, so skip the edge back to the parent.
        if (child == f) {
            continue;
        }
        dfs1(child, u);
        sz[u] += sz[child];
    }
}

// Second DFS pass: hands out DFS positions and chain tops.
//   u - vertex being visited
//   c - top vertex of the chain u belongs to
// The neighbour with the largest subtree is visited first and stays on the
// same chain; every other unvisited neighbour starts a chain of its own.
void dfs2(int u , int c) {
    in[u] = ++nn;
    path[nn] = u;
    rt[u] = c;

    // Pick the heavy child. Neighbours whose size is not smaller than u's are
    // skipped (with dfs1's sizes that is the parent). sz[0] == 0 makes 0 act
    // as "none", and ties keep the earliest neighbour.
    int heavy = 0;
    for (auto v : g[u]) {
        if (sz[v] >= sz[u]) {
            continue;
        }
        if (sz[v] > sz[heavy]) {
            heavy = v;
        }
    }
    if (!heavy) {
        return;
    }

    dfs2(heavy, c);
    // rt[v] != 0 marks vertices that were already placed (parent, heavy child).
    for (auto v : g[u]) {
        if (rt[v] != 0) {
            continue;
        }
        dfs2(v, v);
    }
}
// Lowest common ancestor of u and v via chain jumping: while the two vertices
// sit on different chains, the one whose chain top is deeper (u on ties) moves
// to the parent of that chain top. Once they share a chain the shallower
// vertex is returned (u on ties).
int lca(int u ,int v) {
    while (rt[u] != rt[v]) {
        int& mover = (dep[rt[u]] >= dep[rt[v]]) ? u : v;
        mover = fa[rt[mover]];
    }
    if (dep[v] < dep[u]) {
        return v;
    }
    return u;
}

// Segment tree node. The same type doubles as an update tag, in which case
// only lazyd is read (see apply()).
struct S {
    int sum;    // sum over the node's range; merge()/apply() reduce it modulo `mod`
    int mx;     // maximum over the node's range
    int len;    // number of positions the node covers
    int lazyd;  // addition already applied to this node but not yet pushed to its children

    // Drops the pending addition (called after it has been pushed down).
    void reset() {
        lazyd = 0;
    }
} segt[maxn * 4];
// Combines two adjacent ranges into their parent node: lengths add up, sums
// add up modulo `mod`, the maximum is the larger of the two, and the result
// carries no pending addition.
S merge(const S& s1, const S& s2) {
    S res;
    res.lazyd = 0;
    res.mx = (s1.mx < s2.mx) ? s2.mx : s1.mx;
    res.sum = (s1.sum + s2.sum) % mod;
    res.len = s1.len + s2.len;
    return res;
}
// Applies the addition carried by `tag` to node `s`.
//   s   - node to modify; its sum, maximum and pending addition are updated
//   tag - only tag.lazyd is read: the amount added to every position of s
// The sum is kept in [0, mod) provided it was non-negative on entry.
void apply(S& s, const S& tag) {
    s.lazyd += tag.lazyd;
    // Reduce the tag before multiplying by the range length: an accumulated
    // lazy value times a long range can exceed the 64-bit range otherwise.
    int add = tag.lazyd % mod;
    if (add < 0) {
        add += mod;  // C++ '%' keeps the sign of the dividend
    }
    s.sum = (s.sum + add * s.len) % mod;
    // The maximum tracks the real (unreduced) values.
    s.mx += tag.lazyd;
}

void build(int idx, int L, int R) {
    segt[idx].lazyd = 0;
    if (L == R) {
        segt[idx].len = 1;
        segt[idx].sum = segt[idx].mx = 0;
        return;
    }
    int M = (L+R) / 2;
    build(idx*2, L, M);
    build(idx*2+1, M+1, R);
    segt[idx] = merge(segt[idx*2], segt[idx*2+1]);
}
// Inclusive bounds of the range the next query()/update() call works on.
// Callers set them before starting the recursion.
int ql, qr;

// Returns the aggregate of [ql, qr] restricted to node idx, which covers [L, R].
// The node's pending addition is pushed to both children before descending.
S query(int idx, int L, int R) {
    const bool covered = (L >= ql) && (R <= qr);
    if (covered) {
        return segt[idx];
    }

    const int left = idx * 2;
    const int right = left + 1;
    S& cur = segt[idx];
    apply(segt[left], cur);
    apply(segt[right], cur);
    cur.reset();

    const int mid = (L + R) / 2;
    if (qr <= mid) {
        return query(left, L, mid);
    }
    if (ql > mid) {
        return query(right, mid + 1, R);
    }
    // The query range straddles the midpoint: combine both halves.
    const S lower = query(left, L, mid);
    const S upper = query(right, mid + 1, R);
    return merge(lower, upper);
}
void update(int idx, int L, int R, const S& tag) {
    if (ql <= L && R <= qr) {
        apply(segt[idx], tag);
        return;
    }
    int M = (L+R) / 2;
    apply(segt[idx*2], segt[idx]);
    apply(segt[idx*2+1], segt[idx]);
    segt[idx].reset();

    if (ql <= M)
        update(idx*2, L, M,tag);
    if (qr > M)
        update(idx*2+1, M+1, R, tag);
    segt[idx] = merge(segt[idx*2], segt[idx*2+1]);
}
int qrange(int x , int y){
    int sum = 0;
    while(rt[x] != rt[y]){
        if(dep[rt[x]] < dep[rt[y]]){
            swap(x,y);
        }
        ql = in[rt[x]];
        qr = in[x];
        sum += query(r,1,2 * n).sum;
        sum %= mod;
        x = fa[rt[x]];
    }
    if(dep[x] > dep[y]) swap(x,y);
    ql = in[x];
    qr = in[y];
    sum += query(r,1,2 * n).sum;
    return sum; 
}
void upd(int x , int y , S c){
    while(rt[x] != rt[y]){
        if(dep[rt[x]] < dep[rt[y]]){
            swap(x,y);
        }
        ql = in[rt[x]];
        qr = in[x];
        update(r,1,2 * n,c);
        x = fa[rt[x]];
    }
    if(dep[x] > dep[y]) swap(x,y);
    ql = in[x];
    qr = in[y];
    update(r,1,n,c);
}
signed main() {
    cin >> n >> q >> r >> mod;
    for(int i = 1 ; i <= n ; i++) {
        cin >> a[i];
    }
    for(int i = 1 ; i < n ; i++) {
        int u , v;
        scanf("%lld %lld",&u,&v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    build(r,1,2*n);
    dfs1(r,0);

    dfs2(r,r);
    for(int i = 1 ; i <= n ; i++) {
        ql = in[i];
        qr = in[i];
        S tag;
        tag.lazyd = a[i];
        update(r,1,2 * n,tag);
    }
    while(q--) {
        int op , x , y , z;
        cin >> op;
        if(op == 1) {
            cin >> x >> y >> z;
            ql = out[x];
            qr = in[y];
            S tag;
            tag.lazyd = z;
            upd(x,y,tag);
        } else if(op == 2) {
            cin >> x >> y;
            cout << qrange(x,y) % mod<<endl;
        } else if(op == 3) {
            cin >>x >> z;
            ql = in[x];
            qr = in[x] + sz[x] - 1;
//          cout << ql << " " << qr << endl;
            S tag;
            tag.lazyd = z;
            update(r,1,2 * n,tag);
        } else {
            cin >> x;
            ql = in[x];
            qr = in[x] + sz[x] - 1;
            cout << query(r,1,2 * n).sum% mod<< endl;
        }
    }
}

|