最后一个测试点 WA 了，找不出哪里错了，求调 QwQ

P3384 【模板】重链剖分/树链剖分

__shadow__ @ 2023-10-09 17:37:58

rt


by __shadow__ @ 2023-10-09 17:38:53

https://www.luogu.com.cn/record/128419309

#include <cstdio>
#include <vector>
// 64-bit integer alias used for weights, sums and the modulus.
using ll = long long;
using namespace std;

// Capacity of every node-indexed array: 1e5 plus a little slack.
const int N = 100000 + 5;

// n: number of nodes, m: number of operations, r: root node.
int n, m, r;

// p: modulus; a[v]: initial weight of node v; b[i]: weight placed at DFS position i.
ll p;
ll a[N];
ll b[N];

// Adjacency lists of the undirected tree.
vector<int> V[N];

// Heavy-light decomposition data: parent, depth, subtree size, heavy child,
// DFS position, DFS position counter, and top node of the node's chain.
int fa[N];
int dep[N];
int siz[N];
int hson[N];
int id[N];
int cnt;
int top[N];
// First DFS pass over the subtree rooted at sx (whose parent is ffa):
// records parent, depth and subtree size of every node. It also picks as
// heavy child the child with the strictly largest subtree; the first such
// child wins ties.
void dfs1(int sx, int ffa) {
    fa[sx] = ffa;
    dep[sx] = dep[ffa] + 1;
    siz[sx] = 1;
    int best = -1;  // largest child subtree size seen so far
    for (std::size_t i = 0; i < V[sx].size(); ++i) {
        const int child = V[sx][i];
        if (child != ffa) {
            dfs1(child, sx);
            siz[sx] += siz[child];
            if (best < siz[child]) {
                best = siz[child];
                hson[sx] = child;
            }
        }
    }
}
// Second DFS pass: gives each node its DFS-order position id[] and records
// the top node of its chain. The heavy child is visited first, so every
// heavy chain occupies a contiguous range of positions. Each node's weight
// is copied into b[] at its position.
void dfs2(int sx, int topf) {
    top[sx] = topf;
    id[sx] = ++cnt;
    b[id[sx]] = a[sx];
    const int heavy = hson[sx];
    if (heavy != 0) {
        dfs2(heavy, topf);  // the heavy child continues the current chain
    }
    for (const int to : V[sx]) {
        if (to != fa[sx] && to != heavy) {
            dfs2(to, to);   // every light child starts a new chain
        }
    }
}
struct sgt {
    ll sum[N << 2], lz[N << 2];
    void PushUp(int u) {
        sum[u] = (sum[u << 1] + sum[u << 1 | 1]) % p;
        return ;
    }
    void build(int u, int l, int r) {
        lz[u] = 0;
        if (l == r) {
            sum[u] = b[l];
            return ;
        }
        int mid = l + r >> 1;
        build (u << 1, l, mid);
        build (u << 1 | 1, mid + 1, r);
        PushUp (u);
        return ;
    }
    void PushDown(int u, int s) {
        if (lz[u] == 0) return ;
        (lz[u << 1] += lz[u]) %= p;
        (lz[u << 1 | 1] += lz[u]) %= p;
        (sum[u << 1] += lz[u] * (s - (s >> 1))) %= p;
        (sum[u << 1 | 1] += lz[u] * (s >> 1)) %= p;
        lz[u] = 0;
        return ;
    }
    void modify(int u, int l, int r, int L, int R, ll k) {
        if (L <= l && r <= R) {
            (sum[u] += k * (r - l + 1)) %= p;
            (lz[u] += k) %= p;
            return ;
        }
        PushDown (u, r - l + 1);
        int mid = l + r >> 1;
        if (L <= mid)
            modify (u << 1, l, mid, L, R, k);
        if (mid < R)
            modify (u << 1 | 1, mid + 1, r, L, R, k);
        PushUp (u);
        return ;
    }
    ll query(int u, int l, int r, int L, int R) {
        if (L <= l && r <= R)
            return sum[u];
        PushDown(u, r - l + 1);
        ll ans = 0;
        int mid = l + r >> 1;
        if (L <= mid)
            ans += query(u << 1, l, mid, L, R);
        if (mid < R)
            ans += query(u << 1 | 1, mid + 1, r, L, R);
        return ans % p;
    }
    void mod_1(int x, int y, ll k) {
        while (top[x] != top[y]) {
            if (dep[top[x]] > dep[top[y]]) {
                modify (1, 1, n, id[top[x]], id[x], k);
                x = fa[top[x]];
            }
            else {
                modify (1, 1, n, id[top[y]], id[y], k);
                y = fa[top[y]];
            }
        }
        if (dep[x] > dep[y])
            modify (1, 1, n, id[y], id[x], k);
        else modify (1, 1, n, id[x], id[y], k);
        return ;
    }
    ll que_1(int x, int y) {
        ll ans = 0;
        while (top[x] != top[y]) {
            if (dep[top[x]] > dep[top[y]]) {
                (ans += query (1, 1, n, id[top[x]], id[x])) %= p;
                x = fa[top[x]];
            }
            else {
                (ans += query (1, 1, n, id[top[y]], id[y])) %= p;
                y = fa[top[y]];
            }
        }
        if (dep[x] > dep[y])
            (ans += query (1, 1, n, id[y], id[x])) %= p;
        else (ans += query (1, 1, n, id[x], id[y])) %= p;
        return ans;
    }
}tre;
// Reads the tree and node weights and builds the decomposition and segment
// tree, then answers m operations:
//   1 x y z : add z to every node on path x - y
//   2 x y   : print the path sum x - y (mod p)
//   3 x k   : add k to every node in the subtree of x
//   4 x     : print the subtree sum of x (mod p)
// A subtree occupies positions id[x] .. id[x] + siz[x] - 1 in DFS order.
int main() {
    scanf ("%d%d%d%lld", &n, &m, &r, &p);
    for (int v = 1; v <= n; ++v)
        scanf ("%lld", &a[v]);
    for (int edge = 1; edge < n; ++edge) {
        int u, v;
        scanf ("%d%d", &u, &v);
        V[u].push_back(v);
        V[v].push_back(u);
    }
    dfs1(r, r);
    dfs2(r, r);
    tre.build(1, 1, n);

    while (m-- != 0) {
        int opt;
        scanf ("%d", &opt);
        switch (opt) {
        case 1: {
            int x, y;
            ll z;
            scanf ("%d%d%lld", &x, &y, &z);
            tre.mod_1(x, y, z);
            break;
        }
        case 2: {
            int x, y;
            scanf ("%d%d", &x, &y);
            printf ("%lld\n", tre.que_1(x, y));
            break;
        }
        case 3: {
            int x;
            ll k;
            scanf ("%d%lld", &x, &k);
            const int first = id[x];
            const int last = id[x] + siz[x] - 1;
            tre.modify(1, 1, n, first, last, k);
            break;
        }
        case 4: {
            int x;
            scanf ("%d", &x);
            const int first = id[x];
            const int last = id[x] + siz[x] - 1;
            printf ("%lld\n", tre.query(1, 1, n, first, last));
            break;
        }
        default:
            break;  // unknown operation codes are ignored
        }
    }
    return 0;
}

by crimson000 @ 2023-10-09 18:18:14

可能有负数？试试用 (x % mod + mod) % mod 取模。


by Mathew_Miao @ 2023-10-09 18:20:02

第 150、163 行（按所贴代码自身的行号）的 printf 输出时要对 p 取模：

printf ("%lld\n", tre.que_1 (x, y) % p);
printf ("%lld\n", tre.query(1, 1, n, id[x], id[x] + siz[x] - 1) % p);

by Digital_Sunrise @ 2023-10-13 22:45:04

sgt 的 build 函数中，叶子节点的赋值

应改为（输入的点权可能不小于 p，叶子节点必须先取模）：

if (l == r)
{
    sum[u] = b[l] % p;
    return ;
}

by __shadow__ @ 2023-10-15 14:46:41

@Digital_Sunrise 谢谢,是的


|