玄关求条

P3384 【模板】重链剖分/树链剖分

AC_Boy @ 2024-08-27 17:05:24

Link


by yangwuqi @ 2024-08-27 19:15:48

#include <bits/stdc++.h>
#define int long long
#define lc (root*2)
#define rc (root*2+1)
using namespace std;
// Capacity for vertex-indexed arrays: the vertex bound plus a little slack.
const int N = 1e5 + 10;

// Undirected tree stored as adjacency lists (vertices are 1-indexed).
vector<int> edge[N];

int n;    // number of vertices
int m;    // number of operations
int s;    // root vertex
int cnt;  // running DFS-order counter, advanced by dfs2
int MOD;  // modulus applied to every stored sum

int a[N];     // initial weight of each vertex
int fa[N];    // parent of each vertex (the root is passed as its own parent)
int dep[N];   // depth of each vertex; the root gets depth 1
int siz[N];   // size of the subtree rooted at each vertex
int hson[N];  // heavy (largest-subtree) child, 0 for a leaf
int head[N];  // topmost vertex of the heavy chain the vertex belongs to
int dfn[N];   // 1-based position of the vertex in the heavy-first DFS order
int to[N];    // to[dfn[v]] == a[v]: weights laid out in DFS order

// Segment-tree node covering positions [lft, rgt] of the DFS order.
struct Node {
    int lft;   // left end of the covered range
    int rgt;   // right end of the covered range
    int sum;   // range sum modulo MOD (already includes this node's own tag)
    int lazy;  // pending add-tag not yet pushed down to the children
};
Node seg[4 * N];

// First HLD pass: records parent and depth, computes subtree sizes and picks
// the heavy child (the first neighbour, in adjacency order, whose subtree is
// strictly the largest). Returns the size of root's subtree, which is also
// stored in siz[root]. Recursion depth equals the height of the tree.
int dfs1(int root, int fath) {
    fa[root] = fath;
    dep[root] = dep[fath] + 1;

    int subtree = 1;  // root itself
    int bestSize = -1;
    int heavy = 0;
    for (int child : edge[root]) {
        if (child == fath)
            continue;
        const int childSize = dfs1(child, root);
        subtree += childSize;
        if (childSize > bestSize) {
            bestSize = childSize;
            heavy = child;
        }
    }

    hson[root] = heavy;
    siz[root] = subtree;
    return subtree;
}

// Second HLD pass: assigns DFS positions visiting the heavy child first, so
// every heavy chain occupies a contiguous range of positions. `t` is the top
// vertex of the chain that `root` belongs to. Also copies each vertex weight
// into to[] at its DFS position.
void dfs2(int root, int t) {
    head[root] = t;
    dfn[root] = ++cnt;
    to[dfn[root]] = a[root];

    if (hson[root] == 0)  // leaf: nothing below
        return;

    dfs2(hson[root], t);  // heavy child continues the current chain
    for (int child : edge[root]) {
        if (child == fa[root] || child == hson[root])
            continue;
        dfs2(child, child);  // every light child starts a new chain
    }
}
// Heavy-light decomposition (树链剖分): the two DFS passes above.

// Builds the segment tree node `root` over DFS positions [l, r] from to[],
// clearing all lazy tags. Leaf values are normalised into [0, MOD) so that
// negative input weights cannot make stored sums (and printed answers)
// negative; `%` in C++ keeps the sign of the dividend.
void build(int root, int l, int r) {
    seg[root].lft = l;
    seg[root].rgt = r;
    seg[root].lazy = 0;
    if (l == r) {
        seg[root].sum = (to[l] % MOD + MOD) % MOD;
        return;
    }
    int mid = (l + r) / 2;
    build(lc, l, mid);
    build(rc, mid + 1, r);
    seg[root].sum = (seg[lc].sum + seg[rc].sum) % MOD;
}

void pushdown(int root) {
    seg[lc].sum = (seg[lc].sum + (seg[lc].rgt - seg[lc].lft + 1) * seg[root].lazy % MOD) % MOD;
    seg[rc].sum = (seg[rc].sum + (seg[rc].rgt - seg[rc].lft + 1) * seg[root].lazy % MOD) % MOD;
    seg[lc].lazy = (seg[lc].lazy + seg[root].lazy) % MOD;
    seg[rc].lazy = (seg[rc].lazy + seg[root].lazy) % MOD;
    seg[root].lazy = 0;
}

// Adds k to every DFS position in [l, r], starting from segment node `root`.
// Fully covered nodes absorb the update into sum and lazy. Partially covered
// nodes push their tag down, recurse into both children, and recompute sum.
// k is first normalised into [0, MOD): without this, a negative increment
// makes sums negative, and a huge one can overflow (range length) * k.
void add(int root, int l, int r, int k) {
    k %= MOD;
    if (k < 0)
        k += MOD;

    if (l > seg[root].rgt || r < seg[root].lft)
        return;  // disjoint from the update range
    if (l <= seg[root].lft && seg[root].rgt <= r) {
        seg[root].lazy = (seg[root].lazy + k) % MOD;
        seg[root].sum = (seg[root].sum + (seg[root].rgt - seg[root].lft + 1) * k % MOD) % MOD;
        return;
    }
    if (seg[root].lazy)
        pushdown(root);
    add(lc, l, r, k);
    add(rc, l, r, k);
    seg[root].sum = (seg[lc].sum + seg[rc].sum) % MOD;
}

// Returns the sum of DFS positions [l, r] modulo MOD, searching from segment
// node `root`. Pending tags are pushed down before descending so the
// children's sums are up to date.
int query(int root, int l, int r) {
    const auto &node = seg[root];
    if (r < node.lft || node.rgt < l)
        return 0;  // no overlap
    if (l <= node.lft && node.rgt <= r)
        return node.sum % MOD;  // fully covered

    if (node.lazy)
        pushdown(root);
    const int leftPart = query(lc, l, r);
    const int rightPart = query(rc, l, r);
    return (leftPart + rightPart) % MOD;
}
// Segment tree (线段树) with range add / range sum, above.

// Adds k to every vertex on the tree path x -> y. While x and y lie on
// different heavy chains, the endpoint whose chain top is deeper has that
// chain segment updated and then jumps above the chain top. Once both are on
// the same chain, the remaining contiguous DFS range is updated.
void addlian(int x, int y, int k) {
    for (; head[x] != head[y]; x = fa[head[x]]) {
        if (dep[head[x]] < dep[head[y]])
            swap(x, y);
        add(1, dfn[head[x]], dfn[x], k);
    }
    // Same chain: make x the shallower endpoint so dfn[x] <= dfn[y].
    if (dep[x] >= dep[y])
        swap(x, y);
    add(1, dfn[x], dfn[y], k);
}

// Adds k to every vertex in the subtree of `root`. The subtree occupies the
// contiguous DFS range [dfn[root], dfn[root] + siz[root] - 1].
void addtree(int root, int k) {
    const int first = dfn[root];
    const int last = first + siz[root] - 1;
    add(1, first, last, k);
}

// Returns the sum of vertex weights on the tree path x -> y modulo MOD.
// Chain-jumping works as in addlian: the endpoint whose chain top is deeper
// contributes its chain segment and jumps above it, until both endpoints
// share a heavy chain.
int querylian(int x, int y) {
    int total = 0;
    for (; head[x] != head[y]; x = fa[head[x]]) {
        if (dep[head[x]] < dep[head[y]])
            swap(x, y);
        total = (total + query(1, dfn[head[x]], dfn[x])) % MOD;
    }
    // Same chain: make x the shallower endpoint so dfn[x] <= dfn[y].
    if (dep[x] >= dep[y])
        swap(x, y);
    total = (total + query(1, dfn[x], dfn[y])) % MOD;
    return total;
}

// Returns the sum of weights in the subtree of `root` modulo MOD, read from
// its contiguous DFS range.
int querytree(int root) {
    const int first = dfn[root];
    const int last = first + siz[root] - 1;
    return query(1, first, last) % MOD;
}

// Reads the tree and weights, builds the decomposition and segment tree, and
// then answers m operations:
//   1 x y z : add z to every vertex on path x -> y
//   2 x y   : print the path sum x -> y
//   3 x z   : add z to every vertex in the subtree of x
//   4 x     : print the subtree sum of x
// Stream synchronisation with C stdio is disabled and output uses '\n'
// instead of std::endl. Otherwise every answer line would force a flush,
// which is a needless I/O bottleneck for many queries.
signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    cin >> n >> m >> s >> MOD;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        edge[u].push_back(v);
        edge[v].push_back(u);
    }

    dfs1(s, s);  // the root is passed as its own parent
    dfs2(s, s);
    build(1, 1, n);

    for (int i = 1; i <= m; i++) {
        int op, x, y, z;
        cin >> op;
        if (op == 1) {
            cin >> x >> y >> z;
            addlian(x, y, z);
        } else if (op == 2) {
            cin >> x >> y;
            cout << querylian(x, y) << '\n';
        } else if (op == 3) {
            cin >> x >> z;
            addtree(x, z);
        } else if (op == 4) {
            cin >> x;
            cout << querytree(x) << '\n';
        }
    }
    return 0;
}

by yangwuqi @ 2024-08-27 19:26:29

  1. query 里的 if (!seg[root].lazy) 应改为 if (seg[root].lazy)(只有存在懒标记时才需要下传)

  2. 所有 to[x] 改成 dfn[x];子树区间的右端点应该是 dfn[x] + siz[x] - 1,而不是 dfn[x + siz[x] - 1]

  3. 链那里先跳的应该是重链顶部深度较大的,而不是结点深度较大的,否则会出现这样的情况:

求 4-7 的和时,3 会被重复计算,1 会被多算


by yangwuqi @ 2024-08-27 19:29:28

还有提问之前先过样例/看讨论区的Hack和警示后人是一种美德,不要连样例都不过就发求助/lb


by AC_Boy @ 2024-10-02 16:32:37

@yangwuqi Orz


|