WA on 2-10

P3384 【模板】重链剖分/树链剖分

rentianxiang @ 2024-08-19 20:00:00

37pts,2-10WA了,蒟蒻求调

#include<bits/stdc++.h>
using namespace std;
#define pl p << 1
#define pr p << 1 | 1
const int N = 300010;
int n, m, md, rt, a[N], fa[N], dfn[N], dep[N], cnt, rnk[N], hson[N], sz[N], top[N];
int d[N], lt[N];
vector<int> G[N];
void push_up(int s, int t, int p) {
    d[p] = (d[pl] + d[pr]) % md;
}
void push_down(int s, int t, int p) {
    int m = s + ((t - s) >> 1);
    d[pl] = (d[pl] + (m - s + 1) * lt[p] % md) % md;
    d[pr] = (d[pr] + (t - m) * lt[p] % md) % md;
    lt[pl] = (lt[p] + lt[pl]) % md;
    lt[pr] = (lt[p] + lt[pr]) % md;
    lt[p] = 0;
}
void build(int s, int t, int p) {
    if (s == t) {
        d[p] = a[rnk[s]] % md;
        return ;
    }
    int m = s + ((t - s) >> 1);
    build(s, m, pl);
    build(m + 1, t, pr);
    push_up(s, t, p);
}
void update(int l, int r, int c, int s, int t, int p) {
    if (l <= s && t <= r) {
        d[p] = (d[p] + (t - s + 1) * c % md) % md;
        lt[p] = (lt[p] + c) % md;
        return ;
    }
    int m = s + ((t - s) >> 1);
    if (lt[p]) {
        push_down(s, t, p);
    }
    if (l <= m) {
        update(l, r, c, s, m, pl);
    }
    if (r > m) {
        update(l, r, c, m + 1, t, pr);
    }
    push_up(s, t, p);
}
int query(int l, int r, int s, int t, int p) {
    if (l <= s && t <= r) {
        return d[p];
    }
    int m = s + ((t - s) >> 1);
    if (lt[p]) {
        push_down(s, t, p);
    }
    int ret = 0;
    if (l <= m) {
        ret = (ret + query(l, r, s, m, pl)) % md;
    }
    if (r > m) {
        ret = (ret + query(l, r, m + 1, t, pr)) % md;
    }
    return ret;
}
void dfs1(int x, int f) {
    sz[x] = 1;
    hson[x] = -1;
    fa[x] = f;
    dep[x] = dep[f] + 1;
    for (int i = 0; i < G[x].size(); i++) {
        int y = G[x][i];
        if (y == f) continue;
        dfs1(y, x);
        sz[x] += sz[y];
        if (i == 0 || sz[y] > sz[hson[x]]) {
            hson[x] = y;
        }
    } 
}
void dfs2(int x, int tp) {
    top[x] = tp;
    dfn[x] = ++cnt;
    rnk[cnt] = x;
    for (int i = 0; i < G[x].size(); i++) {
        int y = G[x][i];
        if (y == fa[x]) continue;
        if (y == hson[x])   dfs2(hson[x], top[x]);
        else    dfs2(y, y);
    }
}
void update_lca(int x, int y, int c) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) {
            swap(x, y);
        }
        update(dfn[top[x]], dfn[x], c, 1, n, 1);
        x = fa[top[x]];
    }
    if (dfn[x] > dfn[y])    swap(x, y);
    update(dfn[x], dfn[y], c, 1, n, 1);
}
int query_lca(int x, int y) {
    int ret = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) {
            swap(x, y);
        }
        ret = (ret + query(dfn[top[x]], dfn[x], 1, n, 1)) % md;
        x = fa[top[x]];
    }
    if (dfn[x] > dfn[y])    swap(x, y);
    ret = (ret + query(dfn[x], dfn[y], 1, n, 1)) % md;
    return ret;
}
signed main() {
    scanf("%d%d%d%d", &n, &m, &rt, &md);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
    }
    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        G[x].push_back(y);
        G[y].push_back(x);
    }
    dfs1(rt, rt);
    dfs2(rt, rt);
    build(1, n, 1);
    for (int i = 1; i <= m; i++) {
        int op, x, y, z;
        scanf("%d", &op);
        if (op == 1) {
            scanf("%d%d%d", &x, &y, &z);
            z = z % md;
            update_lca(x, y, z);
        }
        if (op == 2) {
            scanf("%d%d", &x, &y);
            cout << query_lca(x, y) << endl;
        }
        if (op == 3) {
            scanf("%d%d", &x, &y);
            y = y % md;
            update(dfn[x], dfn[x] + sz[x] - 1, y, 1, n, 1);
        }   
        if (op == 4) {
            scanf("%d", &x);
            cout << query(dfn[x], dfn[x] + sz[x] - 1, 1, n, 1) << endl;
        }
    }
    return 0;
}

by rentianxiang @ 2024-08-19 20:11:46

看了楼下大佬的回复,发现dfs2应该优先遍历重儿子,加一句

if (hson[x] != -1) {
    dfs2(hson[x], top[x]);
}

就过啦

感谢大佬@i10eg

此帖已结


|