萌新刚学 OI，树上带修莫队 RE 了 8 个点，求助

P4074 [WC2013] 糖果公园

shinzanmono @ 2022-10-12 22:19:37

#include <bits/stdc++.h>
using namespace std;
// 64-bit accumulator type: answers are sums of products of two input values.
using ll = long long;
// Capacity of every per-vertex / per-colour / per-query array (1e5 plus a small margin).
constexpr int sz = 100000 + 10;
// Number of sparse-table levels: levels 0 .. __lg(sz) of f[][] must fit.
constexpr int lgsz = __lg(sz) + 1;
// One directed edge of the adjacency list (chained forward star).
//   nxt - index of the edge added before this one from the same source (0 = none)
//   to  - target vertex
// Undirected edges are stored twice, hence the doubled capacity.
struct edge {
    int nxt;
    int to;
} graph[sz << 1];
// hpp: number of edges stored so far (edges are 1-indexed, so 0 means "no edge").
// head[u]: index of the most recently added edge leaving u.
int hpp, head[sz];
/// Append the directed edge from -> to to the adjacency list of `from`.
void addEdge(int from, int to) {
    ++hpp;
    graph[hpp].nxt = head[from];  // link to the previous first edge of `from`
    graph[hpp].to = to;
    head[from] = hpp;
}
// dpp: pre-order counter; dfn[u]: pre-order index of u; dep[u]: depth (root's parent is 0, dep[0] = 0).
// f[k][i]: sparse table over pre-order positions; level 0 holds the parent of the
//          vertex whose dfn is i.
// li[1..x]: Euler sequence (each vertex appears on entry and on exit);
// first[u] / last[u]: positions of u's entry / exit in li.
int dpp, dfn[sz], dep[sz], f[lgsz][sz], first[sz], last[sz], li[sz << 1], x;
/// Traverse the subtree of u (whose parent is fa), filling dfn, dep, f[0],
/// li, first and last in exactly the order a recursive DFS would.
/// An explicit stack replaces recursion: on a path-shaped tree the depth reaches
/// n (~1e5), and one native stack frame per level can overflow the stack.
void dfs(int u, int fa) {
    struct Frame {
        int u, fa, p;  // p: next adjacency edge of u still to examine
    };
    std::vector<Frame> stk;
    // Pre-visit work of the recursive version, then push the frame.
    auto enter = [&](int node, int parent) {
        dfn[node] = ++dpp, f[0][dpp] = parent, dep[node] = dep[parent] + 1;
        li[++x] = node, first[node] = x;
        stk.push_back(Frame{node, parent, head[node]});
    };
    enter(u, fa);
    while (!stk.empty()) {
        Frame &top = stk.back();
        if (top.p) {
            int to = graph[top.p].to;
            top.p = graph[top.p].nxt;
            // enter() may reallocate stk; `top` is not used again this iteration.
            if (to != top.fa) enter(to, top.u);
        } else {
            // All children done: post-visit work of the recursive version.
            li[++x] = top.u, last[top.u] = x;
            stk.pop_back();
        }
    }
}
/// Return the shallower of u and v (by dep[]); on a tie, v is returned.
int depmin(int u, int v) {
    if (dep[u] < dep[v]) return u;
    return v;
}
/// Lowest common ancestor of u and v in O(1) using the DFS-order sparse table.
/// For dfn[u] < dfn[v], the LCA is the shallowest among the parents of the
/// vertices whose pre-order index lies in (dfn[u], dfn[v]]. f[0][i] stores those
/// parents, and f[k] combines ranges by depmin.
int lca(int u, int v) {
    // The range would be empty: __lg(0) is undefined and would index f[-1].
    if (u == v) return u;
    int du = dfn[u], dv = dfn[v];
    if (du > dv) swap(du, dv);
    int lg = __lg(dv - du);  // length of the range (du, dv] is dv - du >= 1
    return depmin(f[lg][du + 1], f[lg][dv - (1 << lg) + 1]);
}
// A path query expressed as the Euler-sequence range li[l..r], plus:
//   id            - input order of the query (index into ans[])
//   blockl/blockr - Mo blocks of l and r (first and second sort keys)
//   st            - extra vertex toggled in only while answering (0 = none)
//   tt            - number of modifications read before this query (third sort key)
struct query {
    int l, r, id, blockl, blockr, st, tt;
    // Lexicographic order on (blockl, blockr, tt).
    bool operator<(const query &a) const {
        return blockl != a.blockl ? blockl < a.blockl
             : blockr != a.blockr ? blockr < a.blockr
             : tt < a.tt;
    }
} que[sz];
// A point update: vertex `pos` receives colour `val`. change() swaps `val` with
// the vertex's current colour, so applying the same record again undoes it.
struct Change {
    int pos;
    int val;
} chan[sz];
// ans[id]: answer of the id-th query; cnt: value of the current window;
// v[c]: value of colour c; w[k]: multiplier for the k-th vertex of one colour in the window.
ll ans[sz], cnt, v[sz], w[sz];
// arr[u]: current colour of u; t[c]: vertices of colour c in the window;
// tim[u]: 1 while u is counted in the window, else 0;
// block[i]: Mo block of Euler-sequence position i. Positions run up to 2n
// (each vertex is entered and left once), so the array needs doubled capacity.
int arr[sz], t[sz], tim[sz], block[sz << 1];
/// Toggle vertex x in or out of the current window and keep cnt up to date:
/// the k-th vertex of colour c present contributes w[k] * v[c].
void add(int x) {
    const int c = arr[x];
    if (!tim[x]) {
        ++t[c];
        cnt += w[t[c]] * v[c];
    } else {
        cnt -= w[t[c]] * v[c];
        --t[c];
    }
    tim[x] ^= 1;
}
void change(int tt) {
    if (tim[chan[tt].pos]) {
        cnt -= w[t[arr[chan[tt].pos]]--] * v[arr[chan[tt].pos]];
        cnt += w[++t[chan[tt].val]] * v[chan[tt].val];
    }
    swap(arr[chan[tt].pos], chan[tt].val);
}
int main() {
    ios::sync_with_stdio(false);
    int n, m, q;
    cin >> n >> m >> q;
    for (int i = 1; i <= m; i++) cin >> v[i];
    for (int i = 1; i <= n; i++) cin >> w[i];
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }
    for (int i = 1; i <= n; i++) cin >> arr[i];
    dfs(1, 0);
    for (int i = 1; i <= __lg(n); i++)
        for (int j = 1; j + (1 << i) - 1 <= n; j++)
            f[i][j] = depmin(f[i - 1][j], f[i - 1][j + (1 << i - 1)]);
    for (int i = 1; i <= n; i++)
        block[i] = (i - 1) / sqrt(n) + 1;
    int mx = 0, cx = 0;
    for (int i = 1; i <= q; i++) {
        int op, a, b, l, r;
        cin >> op >> a >> b;
        if (op == 1) {
            if (first[a] > first[b]) swap(a, b);
            if (lca(a, b) == a) {
                l = first[a], r = first[b];
                que[++mx] = query{l, r, mx, block[l], block[r], 0, cx};
            }
            else {
                l = last[a], r = first[b];
                que[++mx] = query{l, r, mx, block[l], block[r], lca(a, b), cx};
            }
        }
        else chan[++cx] = Change{a, b};
    }
    sort(que + 1, que + mx + 1);
    int l = 1, r = 0, tt = 0;
    for (int i = 1; i <= mx; i++) {
        int le = que[i].l, re = que[i].r, te = que[i].tt, st = que[i].st;
        while (l < le) add(li[l++]);
        while (l > le) add(li[--l]);
        while (r < re) add(li[++r]);
        while (r > re) add(li[r--]);
        if (st) add(st);
        while (tt < te) change(++tt);
        while (tt > te) change(tt--);
        ans[que[i].id] = cnt;
        if (st) add(st);
    }
    for (int i = 1; i <= mx; i++)
        cout << ans[i] << "\n";
    return 0;
}

by reveal @ 2022-10-12 22:36:37

建议先看看普通莫队算法。

这里 l++、--l、++r、r-- 这几处指针移动的顺序是错误的。


by shinzanmono @ 2022-10-12 22:42:16

@reveal 感谢


|