37pts 求调 (37 points — asking for help debugging)

P3384 【模板】重链剖分/树链剖分

ek7a @ 2023-07-21 16:41:00

#include <bits/stdc++.h>
using namespace std;
// rep(i, a, b): inclusive loop with i running from a to b.
#define rep(i, a, b) for (int i = (a); i <= (b); ++i)
// From here on every `int` (globals, parameters, return types) is long long;
// main is declared `signed main` because of this.
#define int long long
const int N = 1e5 + 10;     // array capacity: max vertex count plus slack
int n, m, p;                // number of vertices, number of operations, modulus
int a[N];                   // initial vertex weights, 1-indexed
vector<int> G[N];           // undirected adjacency lists
int cnt = 0, rt;            // DFS timestamp counter used by dfs2; root vertex
int top[N], node[N];        // top[u]: head of u's heavy chain; node[i]: vertex whose dfn is i
int sz[N], dep[N], fa[N];   // subtree size; depth (root has depth 1); parent (root's parent is 0)
int son[N], dfn[N];         // heavy child (0 when none); 1-based DFS-order index
// First HLD pass: DFS from u (whose parent is par) filling fa[], dep[], sz[],
// and choosing son[u] as the child with the largest subtree. On ties the first
// such child wins; son[u] stays 0 when u has no children.
void dfs1(int u, int par) {
    fa[u] = par;
    dep[u] = dep[par] + 1;
    sz[u] = 1;
    for (size_t i = 0; i < G[u].size(); ++i) {
        const int v = G[u][i];
        if (v == par) continue;   // don't walk back up the tree
        dfs1(v, u);
        sz[u] += sz[v];
        if (sz[son[u]] < sz[v]) son[u] = v;
    }
}
// Second HLD pass: assigns DFS-order indices dfn[], visiting the heavy child
// first so every heavy chain occupies consecutive positions. node[] is the
// inverse of dfn[]; top[u] is the head of the chain containing u (tp).
void dfs2(int u, int tp) {
    dfn[u] = ++cnt;
    node[dfn[u]] = u;
    top[u] = tp;
    if (son[u] != 0) dfs2(son[u], tp);   // heavy child continues the chain
    for (int v : G[u]) {
        if (dfn[v]) continue;            // parent or heavy child: already numbered
        dfs2(v, v);                      // each light child starts a new chain
    }
}
// Lazy segment tree over DFS-order positions 1..n: range add, range sum, all mod p.
// sum[x] is the (reduced) sum of node x's segment; tag[x] is an addition already
// applied to sum[x] but not yet pushed to x's children.
namespace sgt {
    int sum[N << 2], tag[N << 2];

    // Recompute node x's sum from its two children.
    void psp(int x) {
        sum[x] = (sum[2 * x] + sum[2 * x + 1]) % p;
    }

    // Build node x covering [l, r]; leaf i holds a[node[i]] mod p.
    void build(int x = 1, int l = 1, int r = n) {
        if (l == r) {
            sum[x] = a[node[l]] % p;
            return;
        }
        const int mid = (l + r) / 2;
        build(2 * x, l, mid);
        build(2 * x + 1, mid + 1, r);
        psp(x);
    }

    // Add v to every position of node x's segment [l, r] and remember it as a lazy tag.
    void pst(int x, int l, int r, int v) {
        tag[x] = (tag[x] + v) % p;
        sum[x] = (sum[x] + v * (r - l + 1)) % p;
    }

    // Push node x's pending tag to both children, then clear it.
    void psd(int x = 1, int l = 1, int r = n) {
        if (!tag[x]) return;
        const int mid = (l + r) / 2;
        pst(2 * x, l, mid, tag[x]);
        pst(2 * x + 1, mid + 1, r, tag[x]);
        tag[x] = 0;
    }

    // Add v to every position in [L, R]; (x, l, r) is the current node and its segment.
    void upd(int L, int R, int v, int x = 1, int l = 1, int r = n) {
        if (R < l || L > r) return;          // no overlap
        if (L <= l && r <= R) {              // fully covered: tag and stop
            pst(x, l, r, v);
            return;
        }
        psd(x, l, r);
        const int mid = (l + r) / 2;
        upd(L, R, v, 2 * x, l, mid);
        upd(L, R, v, 2 * x + 1, mid + 1, r);
        psp(x);
    }

    // Sum of positions in [L, R] modulo p; (x, l, r) is the current node and its segment.
    int qry(int L, int R, int x = 1, int l = 1, int r = n) {
        if (R < l || L > r) return 0;
        if (L <= l && r <= R) return sum[x];
        psd(x, l, r);
        const int mid = (l + r) / 2;
        const int lsum = qry(L, R, 2 * x, l, mid);
        const int rsum = qry(L, R, 2 * x + 1, mid + 1, r);
        return (lsum + rsum) % p;
    }
}
void mdf1(int x, int y, int v) {
    while (top[x] != top[y]) {
        if (dep[x] < dep[y]) swap(x, y);
        sgt::upd(dfn[top[x]], dfn[x], v);
        x = fa[top[x]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    sgt::upd(dfn[x], dfn[y], v);
}
// Sum of vertex weights on the tree path between x and y, modulo p.
int qry1(int x, int y) {
    int res = 0;
    // Jump one heavy chain at a time, always from the endpoint whose chain head is deeper.
    for (; top[x] != top[y]; x = fa[top[x]]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        res = (res + sgt::qry(dfn[top[x]], dfn[x])) % p;
    }
    // Same chain now: dfn grows with depth along a heavy chain, so the remaining
    // segment runs from the smaller dfn to the larger one.
    const int lo = std::min(dfn[x], dfn[y]);
    const int hi = std::max(dfn[x], dfn[y]);
    return (res + sgt::qry(lo, hi)) % p;
}
// Add v to every vertex in x's subtree; in DFS order the subtree is the
// contiguous range dfn[x] .. dfn[x] + sz[x] - 1.
void mdf2(int x, int v) {
    const int first = dfn[x];
    const int last = first + sz[x] - 1;
    sgt::upd(first, last, v);
}
// Sum of weights in x's subtree modulo p, i.e. positions dfn[x] .. dfn[x] + sz[x] - 1.
int qry2(int x) {
    const int first = dfn[x];
    const int last = first + sz[x] - 1;
    return sgt::qry(first, last);
}
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m >> rt >> p;
    rep(i, 1, n) cin >> a[i];
    rep(i, 1, n - 1) {
        int u, v;
        cin >> u >> v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs1(rt, 0);
    dfs2(rt, rt);
    sgt::build();
    rep(i, 1, m) {
        int op, x, y, z; 
        cin >> op;
        if (op == 1) {
            cin >> x >> y >> z;
            mdf1(x, y, z);
        } else if (op == 2) {
            cin >> x >> y;
            cout << qry1(x, y) << '\n';
        } else if (op == 3) {
            cin >> x >> z;
            mdf2(x, z);
        } else {
            cin >> x;
            cout << qry2(x) << '\n';
        }
    }
    return 0;
}

by _XHY20180718_ @ 2023-07-21 17:09:39

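@ek7a

In `mdf1`, change

if (dep[x] < dep[y]) swap(x, y);

改成 (to):

if (dep[top[x]] < dep[top[y]]) swap(x, y);

Each step must lift the endpoint whose chain head is deeper. Comparing the endpoints' own depths can move the wrong endpoint past the LCA.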

by ek7a @ 2023-07-21 17:56:51

谢谢!!!


|