TLE 8 9 10求调

P3384 【模板】重链剖分/树链剖分

MINO1 @ 2024-01-06 19:46:44

#include<iostream>
#include<vector>

using namespace std;

int n, m, root, P;
long long point_val[100005];
vector<int> tree[100005];
int fa[100005];
int dep[100005];
int hev[100005];
int sz[100005];

void dfs1(int u, int pre) {
    fa[u] = pre;
    dep[u] = dep[pre] + 1;
    sz[u] = 1;
    int msz = 0; int hv = 0;
    for (auto v : tree[u]) {
        if (v == pre) continue;
        dfs1(v, u);
        if (sz[v] > msz) hv = v;
        sz[u] += sz[v];
    }
    hev[u] = hv;
}

int dfn[100005], rdfn[100005];
int dfn_id = 0; int top[100005];

void dfs2(int u, int topf) {
    dfn[++dfn_id] = u;
    rdfn[u] = dfn_id;
    top[u] = topf;
    if (!hev[u]) return;
    dfs2(hev[u], topf);
    for (auto v : tree[u]) {
        if (v == fa[u] || v == hev[u]) continue;
        dfs2(v, v);
    }
}

long long xds_val[400005];
long long lan[400005];

void create_xds(int p, int s, int t) {
    if (s == t) {
        xds_val[p] = (point_val[dfn[s]]) % P;
        return;
    }
    int m = (s + t) >> 1;
    create_xds(p * 2, s, m);
    create_xds(p * 2 + 1, m + 1, t);
    xds_val[p] = (xds_val[p * 2] + xds_val[p * 2 + 1]) % P;
}

void add2(int p, int l, int r, int s, int t, int val) {
    if (l <= s && r >= t) {
        lan[p] = (lan[p] + val) % P;
        xds_val[p] = (xds_val[p] + (val * (static_cast<long long>(t) - s + 1)) % P) % P;
        return;
    }
    int m = (s + t) >> 1;
    if (lan[p] && s != t) {
        lan[p * 2] = (lan[p * 2] + lan[p]) % P, lan[p * 2 + 1] = (lan[p * 2 + 1] + lan[p]) % P;
        xds_val[p * 2] = (xds_val[p * 2] + (lan[p] * (static_cast<long long>(m) - s + 1)) % P) % P, xds_val[p * 2 + 1] = (xds_val[p * 2 + 1] + (lan[p] * (static_cast<long long>(t) - m)) % P) % P;
        lan[p] = 0;
    }
    if (l <= m) add2(p * 2, l, r, s, m, val);
    if (r > m) add2(p * 2 + 1, l, r, m + 1, t, val);
    xds_val[p] = (xds_val[p * 2] + xds_val[p * 2 + 1]) % P;
}

long long sum2(int p, int l, int r, int s, int t) {
    if (l <= s && r >= t) {
        return xds_val[p] % P;
    }
    int m = (s + t) >> 1;
    if (lan[p] && s != t) {
        lan[p * 2] = (lan[p * 2] + lan[p]) % P, lan[p * 2 + 1] = (lan[p * 2 + 1] + lan[p]) % P;
        xds_val[p * 2] = (xds_val[p * 2] + (lan[p] * (static_cast<long long>(m) - s + 1)) % P) % P, xds_val[p * 2 + 1] = (xds_val[p * 2 + 1] + (lan[p] * (static_cast<long long>(t) - m)) % P) % P;
        lan[p] = 0;
    }
    long long sum = 0;
    if (l <= m) sum = (sum + sum2(p * 2, l, r, s, m)) % P;
    if (r > m) sum = (sum + sum2(p * 2 + 1, l, r, m + 1, t)) % P;
    xds_val[p] = (xds_val[p * 2] + xds_val[p * 2 + 1]) % P;
    return sum % P;
}

void add1(int x, int y, int z) {
    while (top[x] != top[y]) {
        if (dep[top[y]] > dep[top[x]]) swap(x, y);
        add2(1, rdfn[top[x]], rdfn[x], 1, n, z);
        x = fa[top[x]];
    }
    if (dep[y] > dep[x]) swap(x, y);
    add2(1, rdfn[y], rdfn[x], 1, n, z);
}

long long sum1(int x, int y) {
    long long sum = 0;
    while (top[x] != top[y]) {
        if (dep[top[y]] > dep[top[x]]) swap(x, y);
        sum = (sum + sum2(1, rdfn[top[x]], rdfn[x], 1, n)) % P;
        x = fa[top[x]];
    }
    if (dep[y] > dep[x]) swap(x, y);
    sum = (sum + sum2(1, rdfn[y], rdfn[x], 1, n)) % P;
    return sum % P;
}

void solve() {
    cin >> n >> m >> root >> P;
    for (int i = 1; i <= n; i++) {
        cin >> point_val[i];
        point_val[i] = point_val[i] % P;
    }
    for (int i = 1; i < n; i++) {
        int u, v; cin >> u >> v;
        tree[u].push_back(v);
        tree[v].push_back(u);
    }
    dfs1(root, 0);
    dfs2(root, root);
    create_xds(1, 1, n);
    while (m--) {
        int q; cin >> q;
        if (q == 3) {
            int x, z; cin >> x >> z;
            add2(1, rdfn[x], rdfn[x] + sz[x] - 1, 1, n, (z % P));
        }
        if (q == 4) {
            int x; cin >> x;
            cout << sum2(1, rdfn[x], rdfn[x] + sz[x] - 1, 1, n) % P << endl;
        }
        if (q == 1) {
            int x, y, z; cin >> x >> y >> z;
            add1(x, y, (z % P));
        }
        if (q == 2) {
            int x, y; cin >> x >> y;
            cout << sum1(x, y) % P << endl;
        }
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    solve();
}

by sunkuangzheng @ 2024-01-06 19:48:21

@MINO1 找重儿子的时候没有更新 msz

这就相当于直接取每个点的最后一个儿子当重儿子。


by MINO1 @ 2024-01-06 19:50:49

@sunkuangzheng 我靠,这里出问题了,感谢


by MINO1 @ 2024-01-06 19:52:25

@sunkuangzheng ac这么多,我还以为是死循环了...


|