萌新第一次写树剖写挂了求调

P3384 【模板】重链剖分/树链剖分

hys52 @ 2023-10-09 09:46:13

只A两个点,且不过样例

#include <bits/stdc++.h>
using namespace std;

//前置概念:
//重儿子:对于每一个非叶子结点,其子结点为根的子树最大的那个子结点就是它的重儿子
//轻儿子:对于每一个非叶子结点,其所有非重儿子的子结点都是它的轻儿子(根结点也是轻儿子)
//重边:父结点与其重儿子的连边
//轻边:除重边之外的边
//重链:相邻重边相连,连接了若干个重儿子的链
//每一条重链应以一个轻儿子为最顶端结点
//是轻儿子的叶子结点自成一条重链

#define ls (p << 1)
#define rs (p << 1 | 1)
const int N = 1e5 + 5;
int n, m, root, mod;
vector<int> t[N];
int dep[N], fa[N], siz[N], son[N], top[N], id[N], w[N], cnt, wt[N];
//dep:每个结点的深度 fa:每个结点的父结点编号 siz:当前结点为根的子树大小 son:每个结点的重儿子编号
//top:每个结点所在重链的最顶端结点编号 id:把所有重链相接后每个结点的新编号 w:原来每个结点的权值
//cnt:辅助累加得到当前结点的新编号 wt:新编号对应的权值
struct seg {
    int l, r, v, add;
} st[N << 2];  //线段树维护wt,即剖出来的重链组成的序列

void dfs1(int u, int f, int depth) {
    dep[u] = depth;
    fa[u] = f;
    siz[u] = 1;
    int maxson = -1;  //打擂台计算结点u的子结点中最大的siz值,以确定哪个是重儿子
    for (auto v : t[u]) {
        if (v == f) continue;
        dfs1(v, u, depth + 1);
        siz[u] += siz[v];
        if (siz[v] > maxson) maxson = siz[v], son[u] = v;
    }
}

void dfs2(int u, int ttop) {  //ttop是当前结点所在重链的最顶端结点编号
    id[u] = ++cnt;
    wt[cnt] = w[u];
    top[u] = ttop;
    if (!son[u]) return;  //没有重儿子,说明是叶子结点
    dfs2(son[u], ttop);
    for (auto v : t[u]) {
        if (v == fa[u] || v == son[u]) continue;
        dfs2(v, v);  //轻儿子成为新的重链的起点
    }
}

void pushup(int p) {
    st[p].v = (st[ls].v + st[rs].v) % mod;
}

void build(int p, int l, int r) {
    st[p].l = l, st[p].r = r;
    if (l == r) {
        st[p].v = wt[l] % mod;
        return;
    }
    int mid = (l + r) >> 1;
    build(ls, l, mid);
    build(rs, mid + 1, r);
    pushup(p);
}

void pushdown(int p) {
    if (!st[p].add) return;
    st[ls].add = (st[ls].add + st[p].add) % mod;
    st[rs].add = (st[rs].add + st[p].add) % mod;
    st[ls].v = (st[ls].v + st[p].v * (st[ls].r - st[ls].l + 1) % mod) % mod;
    st[rs].v = (st[rs].v + st[p].v * (st[rs].r - st[rs].l + 1) % mod) % mod;
    st[p].add = 0;
}

void update(int p, int l, int r, int x) {
    int L = st[p].l, R = st[p].r;
    if (l <= L && r >= R) {
        st[p].v += x * (R - L + 1);
        st[p].add += x;
        return;
    }
    pushdown(p);
    int mid = (L + R) >> 1;
    if (l <= mid) update(ls, l, r, x);
    if (r > mid) update(rs, l, r, x);
    pushup(p);
}

int query(int p, int l, int r) {
    int L = st[p].l, R = st[p].r;
    if (l <= L && r >= R) return st[p].v;
    pushdown(p);
    int mid = (L + R) >> 1, res = 0;
    if (l <= mid) res = (res + query(ls, l, r)) % mod;
    if (r > mid) res = (res + query(rs, l, r)) % mod;
    return res;
}

void update1(int u, int v, int x) {  //将u到v最短路径上每个结点权值都加x
    //与query1类似,只是操作不同
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        update(1, id[top[u]], id[u], x);
        u = fa[top[u]];
    }
    if (dep[u] > dep[v]) swap(u, v);
    update(1, id[u], id[v], x);
}

int query1(int u, int v) {  //u到v最短路径所有结点权值和
    int sum = 0;
    while (top[u] != top[v]) {  //两个点不在同一条链上
        if (dep[top[u]] < dep[top[v]]) swap(u, v);  //使u点所在链顶端的深度大于u的
        sum = (sum + query(1, id[top[u]], id[u])) % mod;  //结果加上这条链的顶端到u点的点权和
        u = fa[top[u]];  //u跳到其所在链顶端结点的更上一个点,即跳到上面的链末尾
    }
    //两个点到同一条链上了
    if (dep[u] > dep[v]) swap(u, v);  //使v点所在深度更深,这样从u到v就是链的一部分
    sum = (sum + query(1, id[u], id[v])) % mod;  //结果加上链上u到v这一部分的点权和
    return sum;
}

void update2(int u, int x) {  //以u为根的子树中每个结点加x
    //siz[u]是以u为根的子树大小,那么id[u]..id[u] + siz[u] - 1就是树剖后序列里子树的区间
    update(1, id[u], id[u] + siz[u] - 1, x);
}

int query2(int u) {  //以u为根的子树中结点的权值和
    return query(1, id[u], id[u] + siz[u] - 1);  //同上
}

int main() {
    scanf("%d%d%d%d", &n, &m, &root, &mod);
    for (int i = 1; i <= n; ++i) scanf("%d", w + i);
    for (int i = 1, u, v; i < n; ++i) {
        scanf("%d%d", &u, &v);
        t[u].push_back(v);
        t[v].push_back(u);
    }
    dfs1(root, 0, 1);
    dfs2(root, root);
    build(1, 1, n);
    while (m--) {
        int op, u, v, x;
        scanf("%d%d", &op, &u);
        if (op == 1) {
            scanf("%d%d", &v, &x);
            update1(u, v, x);
        } else if (op == 2) {
            scanf("%d", &v);
            printf("%d\n", query1(u, v));
        } else if (op == 3) {
            scanf("%d", &x);
            update2(u, x);
        } else printf("%d\n", query2(u));
    }
    return 0;
}

求大佬帮调,谢谢


by QCurium @ 2023-10-09 09:52:37

你的 pushdown 写挂了,更新子节点值的时候,要用父节点的tag更新,而不是value


by QCurium @ 2023-10-09 10:01:19

你的程序里是把 v 改成 add


by hys52 @ 2023-10-09 10:06:56

@quchenming 过了,感谢


by QCurium @ 2023-10-09 10:16:45

@hys52 没事


|