蒟蒻刚学莫队,0pts求助

P4074 [WC2013] 糖果公园

shinzanmono @ 2022-10-11 12:48:58

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// Capacity of the per-node arrays; indices are 1-based, hence the small slack.
const int sz = 1e5 + 10;
// Number of levels allocated for the sparse table f: floor(log2(sz)) + 1.
const int lgsz = __lg(sz) + 1;
// One directed arc stored in a forward-linked adjacency list.
struct edge {
    int nxt;  // index in graph[] of the next arc leaving the same node, 0 = end of list
    int to;   // node this arc points to
};
// Arc pool; addEdge is called twice per input edge, hence twice the node capacity.
edge graph[sz << 1];
// hpp: number of arcs allocated so far; head[u]: most recently added arc leaving u (0 = none).
int hpp, head[sz];
// Prepends the arc from -> to onto from's adjacency list.
void addEdge(int from, int to) {
    const int slot = ++hpp;
    graph[slot].nxt = head[from];
    graph[slot].to = to;
    head[from] = slot;
}
// dpp: DFS visit counter; dfn[u]: visit number of node u; dep[u]: depth of u (dep of the root's parent 0 stays 0).
// f: sparse table indexed by visit number; f[0][dfn[u]] is u's parent, higher levels keep the shallower entry.
// first[u] / last[u]: positions of u's two occurrences in li; li: enter/exit sequence written by dfs; x: its current length.
// n, m, q: the three counts read at the start of main.
int dpp, dfn[sz], dep[sz], f[lgsz][sz], first[sz], last[sz], li[sz << 1], x, n, m, q;
// Depth-first walk of the subtree of u, entered from parent fa.
// Assigns u its visit number and depth, stores fa at level 0 of the sparse
// table under that visit number, and appends u to li once on entry and once
// on exit, remembering both positions.
void dfs(int u, int fa) {
    ++dpp;
    dfn[u] = dpp;
    f[0][dpp] = fa;
    dep[u] = dep[fa] + 1;

    ++x;
    li[x] = u;
    first[u] = x;

    for (int e = head[u]; e != 0; e = graph[e].nxt) {
        const int child = graph[e].to;
        if (child != fa) dfs(child, u);
    }

    ++x;
    li[x] = u;
    last[u] = x;
}
// Returns whichever of u and v has the smaller dep[]; ties go to v.
int depmin(int u, int v) {
    if (dep[v] <= dep[u]) return v;
    return u;
}
// Fills levels 1..floor(log2(n)) of the sparse table f from level 0:
// f[k][j] is the depmin of the 2^k entries starting at index j.
void init() {
    const int top = __lg(n);
    for (int level = 1; level <= top; ++level) {
        const int half = 1 << (level - 1);
        const int lastStart = n - 2 * half + 1;  // largest j whose window still fits in 1..n
        for (int j = 1; j <= lastStart; ++j) {
            const int left = f[level - 1][j];
            const int right = f[level - 1][j + half];
            f[level][j] = depmin(left, right);
        }
    }
}
// Lowest common ancestor of u and v.
// f[0][i] holds the parent of the node whose visit number is i, so for two
// distinct nodes the shallowest entry among visit numbers (min+1 .. max) is
// their LCA; it is found with two overlapping sparse-table windows.
int lca(int u, int v) {
    // Identical endpoints: the range above would be empty and __lg(0) is
    // undefined, so answer directly.
    if (u == v) return u;
    int du = dfn[u], dv = dfn[v];
    if (du > dv) swap(du, dv);
    const int lg = __lg(dv - du);
    return depmin(f[lg][du + 1], f[lg][dv - (1 << lg) + 1]);
}
// One offline query, as consumed by the processing loop in main:
//   l, r            window ends (indices into li)
//   id              slot in ans[] that receives the result
//   blockl, blockr  block numbers used as the primary / secondary sort keys
//   st              node toggled in only while the answer is read (0 = none)
//   tt              number of modifications that must be applied; last sort key
struct query {
    int l, r, id, blockl, blockr, st, tt;
    // Order: by l across different left blocks, then by r across different
    // right blocks, then by tt.
    bool operator<(const query &other) const {
        if (blockl == other.blockl) {
            if (blockr == other.blockr) return tt < other.tt;
            return r < other.r;
        }
        return l < other.l;
    }
} que[sz];
// One recorded modification.
struct Change {
    int pos;  // index into arr[] that the modification targets
    ll val;   // value to swap into arr[pos]; after applying, holds the value it replaced
};
// Modifications in input order, stored from index 1.
Change chan[sz];
// ans[id]: result stored for query id; t[c]: how many currently counted nodes hold value c;
// tim[u]: 1 while node u is counted, toggled by add; cnt: running total the answers are copied from;
// v[c]: first input array (m entries), multiplied per value c; w[k]: second input array (n entries), indexed by occurrence count;
// block[]: block-number table for the query ordering; arr[u]: value currently held by node u.
ll ans[sz], t[sz], tim[sz], cnt, v[sz], w[sz], block[sz], arr[sz];
void add(int x) {
    if (tim[x]) cnt -= w[t[arr[x]]--] * v[arr[x]];
    else cnt += w[++t[arr[x]]] * v[arr[x]];
    tim[x] ^= 1;
}
void change(int tt, int cl, int cr, int st) {
    int pos = chan[tt].pos, val = chan[tt].val;
    if (pos >= cl && pos <= cr || pos == st)
        cnt -= w[t[arr[pos]]--] * v[arr[pos]];
        cnt += w[++t[val]] * v[val];
    swap(arr[pos], chan[tt].val);
}
int main() {
    ios::sync_with_stdio(false);
    cin >> n >> m >> q;
    for (int i = 1; i <= m; i++) cin >> v[i];
    for (int i = 1; i <= n; i++) cin >> w[i];
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        addEdge(u, v);
        addEdge(v, u);
    }
    for (int i = 1; i <= n; i++) cin >> arr[i];
    dfs(1, 0);
    init();
    int lim = cbrt(n * n);
    for (int i = 1; i <= n; i++)
        block[i] = (i - 1) / lim + 1;
    int mx = 0, cx = 0;
    while (q--) {
        int op, a, b;
        cin >> op >> a >> b;
        if (op == 1) {
            if (first[a] > first[b]) swap(a, b);
            if (lca(a, b) == a) {
                int l = first[a], r = first[b];
                que[++mx] = query{l, r, mx, block[l], block[r], cx, 0};
            }
            else {
                int l = last[a], r = first[b];
                que[++mx] = query{l, r, mx, block[l], block[r], cx, lca(a, b)};
            }
        }
        else chan[++cx] = Change{a, b};
    }
    sort(que + 1, que + mx + 1);
    int l = 1, r = 0, tt = 0;
    for (int i = 1; i <= mx; i++) {
        int le = que[i].l, re = que[i].r, te = que[i].tt, st = que[i].st;
        while (l > le) add(li[--l]);
        while (r < re) add(li[++r]);
        while (r > re) add(li[r--]);
        while (l < le) add(li[l++]);
        while (tt < te) change(++tt, le, re, st);
        while (tt > te) change(tt--, le, re, st);
        if (st) add(st);
        ans[que[i].id] = cnt;
        if (st) add(st);
    }
    for (int i = 1; i <= mx; i++)
        cout << ans[i] << endl;
    return 0;
}

|