树链剖分求指点

P3384 【模板】重链剖分/树链剖分

DESTRUCTION_3_2_1 @ 2023-01-11 13:19:50

// Problem: P3384 【模板】重链剖分/树链剖分
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P3384
// Memory Limit: 125 MB
// Time Limit: 1000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define int long long
#define ls (x * 2)
#define rs (x * 2 + 1)
using namespace std;
const int INF = 2147483647;
const int SZ = 1e5 + 5;

struct segmentTree {
    int l, r;
    int len, sum, tag;
}tr[SZ * 4];
int n, m, r, p, opt, x, y, z, u, v, cnt, maxsize, dfsTime, mid;
int w[SZ], h[SZ * 2], to[SZ * 2], nxt[SZ * 2];
int son[SZ], fa[SZ], sz[SZ], dep[SZ], top[SZ], dfn[SZ], rnk[SZ];

void add (int u, int v) {
    cnt++;
    to[cnt] = v;
    nxt[cnt] = h[u];
    h[u] = cnt;
}
void dfs1 (int u, int f) {
    fa[u] = f, sz[u] = 1, dep[u] = dep[f + 1];
    for (int i = h[u]; i; i = nxt[i]) {
        v = to[i];
        if (v == f) continue;
        dfs1 (v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}
void dfs2 (int u, int t) {
    top[u] = t;
    dfn[u] = ++dfsTime;
    rnk[dfsTime] = u;
    if (!son[u]) return;
    dfs2 (son[u], t);
    for (int i = h[u]; i; i = nxt[i]) {
        v = to[i];
        if (v == fa[u] || v == son[u]) continue;
        dfs2 (v, v);
    }
}
void build (int l, int r, int x = 1) {
    tr[x].l = l, tr[x].l = l, tr[x].len = r - l + 1;
    if (l == r) {
        tr[x].sum = w[rnk[l]];
        return;
    }
    mid = (r + l) / 2;
    build (l, mid, ls);
    build (mid + 1, r, rs);
    tr[x].sum = (tr[ls].sum + tr[rs].sum) % p;
}
void modify (int ql, int qr, int c, int x = 1) {
    int l = tr[x].l, r = tr[x].r;
    if (ql <= l && r <= qr) {
        tr[x].tag += c;
        tr[x].sum += c * tr[x].len;
    }
    if (l > qr || r < ql) return;
    if (tr[x].tag != 0) {
        tr[ls].tag += tr[x].tag, tr[rs].tag += tr[x].tag;
        tr[ls].sum += tr[x].tag * tr[ls].len;
        tr[rs].sum += tr[x].tag * tr[rs].len;
        tr[x].tag = 0;
    }
    modify (ql, qr, c, ls);
    modify (ql, qr, c, rs);
    tr[x].sum = (tr[ls].sum + tr[rs].sum) % p;
}
int query (int ql, int qr, int x = 1) {
    int l = tr[x].l, r = tr[x].r;
    if (ql <= l && r <= qr) return tr[x].sum;
    if (l > qr || r < ql) return 0;
    if (tr[x].tag != 0) {
        tr[ls].tag += tr[x].tag, tr[rs].tag += tr[x].tag;
        tr[ls].sum += tr[x].tag * tr[ls].len;
        tr[rs].sum += tr[x].tag * tr[rs].len;
        tr[x].tag = 0;
    }
    return (query(ql, qr, ls) + query(ql, qr, rs));
}
void mChain (int x, int y, int z) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        modify(dfn[top[x]], dfn[x], z);
        x = fa[top[x]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    modify(dfn[x], dfn[y], z);
}
int qChain (int x, int y) {
    int ret = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        ret += query(dfn[top[x]], dfn[x]);
        x = fa[top[x]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    ret += query(dfn[x], dfn[y]);
    return ret;
}
void mTree (int x, int z) {modify(dfn[x], dfn[x] + sz[x] - 1, z);}
int qTree (int x) {return query(dfn[x], dfn[x] + sz[x] - 1);}
signed main(void)
{
    cin >> n >> m >> r >> p;
    for (int i = 1; i <= n; i++) cin >> w[i];
    for (int i = 1; i <= n - 1; i++) {
        cin >> u >> v;
        add(u, v), add (v, u);
    }
    //cout << "test1" << endl;
    dfs1 (r, 0);
    //cout << "test2" << endl;
    dfs2 (r, r);
    //cout << "test3" << endl;
    build (1, n);
    //cout << "test4" << endl;
    for (int i = 1; i <= m; i++) {
        cin >> opt;
        if (opt == 1) {
            cin >> x >> y >> z;
            mChain(x, y, z);
        }
        if (opt == 2) {
            cin >> x >> y;
            cout << qChain(x, y) << endl;
        }
        if (opt == 3) {
            cin >> x >> z;
            mTree(x, z);
        }
        if (opt == 4) {
            cin >> x;
            cout << qTree(x) << endl;
        }
    }
    return 0;
}

调了半天,求大佬帮忙,谢谢


by zesqwq @ 2023-01-11 13:20:53

modify

    if (ql <= l && r <= qr) {
        tr[x].tag += c;
        tr[x].sum += c * tr[x].len;
    }

对吗,不用 return 是吗


by LYBAKIOI @ 2023-01-11 13:37:47

void dfs1 (int u, int f) {
    fa[u] = f, sz[u] = 1, dep[u] = dep[f + 1];

dep[u] = dep[f + 1];不应该是dep[u] = dep[f] + 1;


by DESTRUCTION_3_2_1 @ 2023-01-11 14:38:15

@LYBAKIOI @zhouershan 感谢


|