AC #2 #3 #11 28pts 求调玄关

P3384 【模板】重链剖分/树链剖分

dengzengxiao @ 2024-11-24 09:21:45

#include<bits/stdc++.h>
using namespace std;
// Widen every `int` in this file to 64 bits (hence `signed main` below), so
// products such as (segment length) * (lazy tag) do not overflow.
#define int long long
// Static array bound for node-indexed arrays.
const int N = 1e6 + 5;

// n: node count, m: operation count, r: root, p: modulus; a[i]: initial value of node i (reduced mod p).
int n, m, r, p, a[N];
// Heavy-light decomposition data filled by dfs1/dfs2:
//   fa = parent, dep = depth (root has depth 1), siz = subtree size, son = heavy child (0 if leaf),
//   tp = top of the node's heavy chain, id[k] = node at DFS position k, rnk[v] = DFS position of v,
//   cnt = number of positions assigned so far.
int fa[N], dep[N], siz[N], son[N], tp[N], id[N], rnk[N], cnt;
// Adjacency lists (linked-list-of-edges form): bg[x] = first edge of x, nxt = next edge, to = endpoint.
// Sized N << 1 because every undirected edge is stored twice.
int tot, bg[N], nxt[N << 1], to[N << 1];

// One segment-tree node covering DFS positions [l, r].
// sum: sum of the covered values modulo p; add: pending lazy increment for the children.
struct Node {
    int l;
    int r;
    int sum;
    int add;
};
Node tree[N << 2];

// Recomputes a node's sum (mod p) from its two children.
void pushup(int pos) {
    const int lhs = tree[pos << 1].sum;
    const int rhs = tree[pos << 1 | 1].sum;
    tree[pos].sum = (lhs + rhs) % p;
}

void pushdown(int pos) {
    Node &root = tree[pos], &left = tree[pos << 1], &right = tree[pos << 1 | 1];
    if(root.add != 0) {
        left.add += root.add;
        left.sum += (left.r - left.l + 1) * root.add;
        left.add %= p; left.sum %= p;
        right.add += root.add;
        right.sum += (right.r - right.l + 1) * root.add;
        right.add %= p; right.sum %= p;
        root.add = 0;
    }
}

// Builds the segment tree over positions [l, r]. Leaves start with sum 0 and tag 0;
// internal sums are recomputed from the children via pushup.
void build(int pos, int l, int r) {
    tree[pos].l = l;
    tree[pos].r = r;
    if (l == r) {
        tree[pos].sum = 0;
        tree[pos].add = 0;
        return;
    }
    const int mid = (l + r) >> 1;
    build(pos << 1, l, mid);
    build(pos << 1 | 1, mid + 1, r);
    pushup(pos);
}

void modify(int pos, int l, int r, int v) {
    if(l <= tree[pos].l && tree[pos].r <= r) {
        tree[pos].add += v;
        tree[pos].sum += (tree[pos].r - tree[pos].l + 1) * v;
        tree[pos].add %= p; tree[pos].sum %= p;
    } else {
        int mid = (tree[pos].l + tree[pos].r) >> 1;
        pushdown(pos);
        if(l <= mid) modify(pos << 1, l, r, v);
        if(r > mid) modify(pos << 1 | 1, l, r, v);
        pushup(pos);
    }
}

int query(int pos, int l, int r) {
    if(l <= tree[pos].l && tree[pos].r <= r) return tree[pos].sum;
    else {
        int mid = (tree[pos].l + tree[pos].r) >> 1, res = 0;
        pushdown(pos);
        if(l <= mid) res += query(pos << 1, l, r);
        if(r > mid) res += query(pos << 1 | 1, l, r);
        return res % p;
    }
}

// Prepends a directed edge x -> y to x's adjacency list.
void addedge(int x, int y) {
    ++tot;
    to[tot] = y;
    nxt[tot] = bg[x];
    bg[x] = tot;
}

// First HLD pass: records parent, depth and subtree size, and picks the heavy child
// (the child with the largest subtree) of every node.
void dfs1(int pos, int last) {
    fa[pos] = last;
    dep[pos] = dep[last] + 1;
    siz[pos] = 1;
    for (int i = bg[pos]; i; i = nxt[i]) {
        const int child = to[i];
        if (child == last) continue;
        dfs1(child, pos);
        siz[pos] += siz[child];
        if (siz[son[pos]] < siz[child]) son[pos] = child;
    }
}

// Second HLD pass: assigns DFS positions, visiting the heavy child first so that
// every heavy chain occupies consecutive positions; `top` is the current chain's head.
void dfs2(int pos, int top) {
    tp[pos] = top;
    rnk[pos] = ++cnt;
    id[cnt] = pos;

    const int heavy = son[pos];
    if (!heavy) return;
    dfs2(heavy, top);
    for (int i = bg[pos]; i; i = nxt[i]) {
        const int child = to[i];
        if (child == fa[pos] || child == heavy) continue;
        dfs2(child, child);  // a light child starts its own chain
    }
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);

    cin >> n >> m >> r >> p;
    for(int i = 1; i <= n; ++i) {
        cin >> a[i];
        a[i] %= p;
    }
    for(int i = 1; i < n; ++i) {
        int x, y; cin >> x >> y;
        addedge(x, y); addedge(y, x);
    }

    dfs1(r, r);
    dfs2(r, r);

    build(1, 1, cnt);
    for(int i = 1; i <= n; ++i)
        modify(1, rnk[i], rnk[i], a[i]);

    while(m--) {
        int opt;
        cin >> opt;

        if(opt == 1) {
            int x, y, z; cin >> x >> y >> z;
            z %= p;
            while(tp[x] != tp[y]) {
                if(dep[tp[x]] < dep[tp[y]]) swap(x, y);
                modify(1, rnk[tp[x]], tp[x], z);
                x = fa[tp[x]];
            }
            if(dep[x] > dep[y]) swap(x, y);
            modify(1, min(rnk[x], rnk[y]), max(rnk[x], rnk[y]), z);
        } else if(opt == 2) {
            int x, y; cin >> x >> y;
            int res = 0;
            while(tp[x] != tp[y]) {
                if(dep[tp[x]] < dep[tp[y]]) swap(x, y);
                res += query(1, rnk[tp[x]], rnk[x]);
                res %= p;
                x = fa[tp[x]];
            }
            if(dep[x] > dep[y]) swap(x, y);
            res += query(1, min(rnk[x], rnk[y]), max(rnk[x], rnk[y]));
            res %= p;
            cout << res << endl;
        } else if(opt == 3) {
            int x, z; cin >> x >> z;
            z %= p;
            int l = rnk[x], r = rnk[x] + siz[x] - 1;
            modify(1, l, r, z);
        } else {
            int x; cin >> x;
            int l = rnk[x], r = rnk[x] + siz[x] - 1;
            cout << query(1, l, r) % p << endl;
        }
    }
    return 0;
}

提交记录


|