树上带修莫队居然会TLE?

P4074 [WC2013] 糖果公园

shinzanmono @ 2022-10-13 10:52:55

#include <bits/stdc++.h>
using namespace std;
// 64-bit shorthand used for weights, values and answers.
using ll = long long;
// Capacity of the per-vertex / per-query arrays (200010).
constexpr int sz = 200000 + 10;
// Row count of the sparse table f[][]: floor(log2(sz)) + 2.
constexpr int lgsz = __lg(sz) + 2;
// Reads one integer from stdin using getchar().
// Every character before the first digit is skipped; if any skipped character
// is '-', the result is negated. The first non-digit after the number is
// consumed as well.
// NOTE(review): reaching EOF before any digit loops forever — acceptable only
// for well-formed judge input.
inline ll read() {
    bool negative = false;
    int c = getchar();
    while (!(c >= '0' && c <= '9')) {
        negative = negative || c == '-';
        c = getchar();
    }
    ll value = 0;
    do {
        value = value * 10 + (c - '0');
        c = getchar();
    } while (c >= '0' && c <= '9');
    return negative ? -value : value;
}
// Forward-star adjacency list. Edges live in graph[1..hpp]; head[u] is the
// index of the most recently added edge leaving u (0 = none), and nxt points
// to the edge added before it for the same source vertex.
struct edge {
    int nxt;
    int to;
};
edge graph[sz << 1];
int hpp;
int head[sz];

// Prepends the directed edge from -> to to from's edge list.
void addEdge(int from, int to) {
    const int id = ++hpp;
    graph[id].nxt = head[from];
    graph[id].to = to;
    head[from] = id;
}
// State filled by dfs():
//   dpp / dfn[u]  preorder counter and preorder index of u (1-based)
//   f[0][k]       parent of the vertex whose preorder index is k
//                 (base row of the sparse table used by lca())
//   dep[u]        depth; dep[0] is never written, so the root gets depth 1
//   x / li[]      bracket sequence: each vertex is appended on entry and again
//                 on exit; x is its running length
//   first[u], last[u]  positions of u's entry / exit in li[]
int dpp;
int dfn[sz];
int dep[sz];
int f[lgsz][sz];
int first[sz];
int last[sz];
int li[sz << 1];
int x;

// Recursive DFS from u (whose parent is fa) recording the data above.
void dfs(int u, int fa) {
    // Entry: preorder number, parent, depth, opening position in li[].
    ++dpp;
    dfn[u] = dpp;
    f[0][dpp] = fa;
    dep[u] = dep[fa] + 1;
    ++x;
    li[x] = u;
    first[u] = x;

    for (int e = head[u]; e != 0; e = graph[e].nxt) {
        const int child = graph[e].to;
        if (child == fa) continue;
        dfs(child, u);
    }

    // Exit: closing position in li[].
    ++x;
    li[x] = u;
    last[u] = x;
}
// Returns whichever of u and v is shallower (smaller dep[]); on a tie, u.
int depmin(int u, int v) {
    if (dep[u] <= dep[v]) return u;
    return v;
}
// Lowest common ancestor of u and v in O(1).
// For u != v, take the preorder range (min(dfn), max(dfn)]. The shallowest
// entry of f[0] (preorder-indexed parents) over that range is the LCA. The
// range is covered by two overlapping power-of-two windows of the sparse
// table f.
int lca(int u, int v) {
    if (u == v) return u;
    const int lo = min(dfn[u], dfn[v]) + 1;
    const int hi = max(dfn[u], dfn[v]);
    const int k = __lg(hi - lo + 1);
    return depmin(f[k][lo], f[k][hi - (1 << k) + 1]);
}
// One path query mapped onto the Euler-tour range [l, r].
//   id             input order of the query (slot in ans[])
//   blockl/blockr  Mo block ids of l and r
//   st             extra vertex (the LCA) toggled in while answering, or 0
//   t              number of modifications read before this query
// Sort order: l first (i.e. l's block). Then r, ascending for odd blockl and
// descending for even. Then t, ascending for odd blockr and descending for
// even. This odd-even ordering reduces pointer back-travel.
struct query {
    int l, r, id, blockl, blockr, st, t;
    bool operator<(const query &a) const {
        if (blockl == a.blockl) {
            if (blockr == a.blockr)
                return (blockr & 1) ? t < a.t : t > a.t;
            return (blockl & 1) ? r < a.r : r > a.r;
        }
        return l < a.l;
    }
};
query que[sz];
// A point modification: vertex pos takes value (type) val.
// change() swaps val with the vertex's current value, so after being applied
// the record holds the previous value and replaying it undoes the change.
struct Change {
    int pos;
    int val;
};
Change chan[sz];
// State of the Mo sweep:
//   ans[id]   recorded answer of query id
//   cnt       value of the current window: sum over types c of
//             v[c] * (w[1] + ... + w[t[c]])
//   v[c]      per-type value, w[k] weight of the k-th occurrence of a type
//   arr[u]    current type of vertex u
//   t[c]      number of counted vertices whose type is c
//   tim[u]    1 while u is counted in the window, else 0
//   block[i]  Mo block id, indexed by Euler-tour positions (query endpoints)
ll ans[sz];
ll cnt;
ll v[sz];
ll w[sz];
int arr[sz];
int t[sz];
int tim[sz];
int block[sz];
// Toggles vertex x in or out of the current window.
// It is called once per Euler-tour position the pointers cross, so tim[x]
// ends up 1 exactly when x occurs an odd number of times in the window.
// Entering bumps its type's occurrence count k and adds w[k] * v[type];
// leaving subtracts that same term.
void add(int x) {
    const int c = arr[x];
    if (!tim[x]) {
        ++t[c];
        cnt += w[t[c]] * v[c];
    } else {
        cnt -= w[t[c]] * v[c];
        --t[c];
    }
    tim[x] ^= 1;
}
void change(int tt) {
    if (tim[chan[tt].pos]) {
        cnt -= w[t[arr[chan[tt].pos]]--] * v[arr[chan[tt].pos]];
        cnt += w[++t[chan[tt].val]] * v[chan[tt].val];
    }
    swap(arr[chan[tt].pos], chan[tt].val);
}
int main() {
    int n = read(), m = read(), q = read();
    for (int i = 1; i <= m; i++) v[i] = read();
    for (int i = 1; i <= n; i++) w[i] = read();
    for (int i = 1; i < n; i++) {
        int u = read(), v = read();
        addEdge(u, v);
        addEdge(v, u);
    }
    for (int i = 1; i <= n; i++) arr[i] = read();
    dfs(1, 0);
    for (int i = 1; i <= __lg(n); i++)
        for (int j = 1; j + (1 << i) - 1 <= n; j++)
            f[i][j] = depmin(f[i - 1][j], f[i - 1][j + (1 << i - 1)]);
    for (int i = 1; i <= n; i++)
        block[i] = (i - 1) / cbrt(n * n) + 1;
    int mx = 0, cx = 0;
    for (int i = 1; i <= q; i++) {
        int op = read(), a = read(), b = read(), l, r;
        if (op == 1) {
            if (first[a] > first[b]) swap(a, b);
            if (lca(a, b) == a) {
                l = first[a], r = first[b];
                que[++mx] = query{l, r, mx, block[l], block[r], 0, cx};
            }
            else {
                l = last[a], r = first[b];
                que[++mx] = query{l, r, mx, block[l], block[r], lca(a, b), cx};
            }
        }
        else chan[++cx] = Change{a, b};
    }
    sort(que + 1, que + mx + 1);
    int l = 1, r = 0, tt = 0;
    for (int i = 1; i <= mx; i++) {
        int le = que[i].l, re = que[i].r, te = que[i].t, st = que[i].st;
        while (l > le) add(li[--l]);
        while (r < re) add(li[++r]);
        while (l < le) add(li[l++]);
        while (r > re) add(li[r--]);
        if (st) add(st);
        while (tt < te) change(++tt);
        while (tt > te) change(tt--);
        ans[que[i].id] = cnt;
        if (st) add(st);
    }
    for (int i = 1; i <= mx; i++)
        printf("%lld\n", ans[i]);
    return 0;
}

|