too young too simple 的问题

P3384 【模板】重链剖分/树链剖分

残阳如血 @ 2024-05-18 21:11:15

一个小问题。

对于下面的代码,为什么 op=4 时需要在 main() 中对线段树查询的结果再次取模,不是已经在 SegmentTree::query() 实现的时候已经取模了吗?

main() 函数中未取模只有 90 分。


by 残阳如血 @ 2024-05-18 21:11:34

#include <vector>
#include <iostream>
const int N = 1e5 + 10;
typedef long long lint;

std::vector<int> g[N];
int n, m, r, p, cnt, a[N];
int fa[N], dep[N], size[N], son[N], top[N], dfn[N], rnk[N];

void dfs1(int u, int f) {
    size[u] = 1, dep[u] = dep[fa[u] = f] + 1;
    for (auto v : g[u]) {
        if (v == f) continue;
        dfs1(v, u);
        size[u] += size[v];
        if (size[v] > size[son[u]]) son[u] = v;
    }
}

void dfs2(int u, int ftop) {
    dfn[u] = ++cnt, rnk[cnt] = u, top[u] = ftop;
    if (!son[u]) return ;
    dfs2(son[u], ftop); // 重儿子
    for (auto v : g[u]) {
        if (v == son[u] || v == fa[u]) continue;
        dfs2(v, v); // 轻儿子
    }
}

namespace SegmentTree {
    lint w[N << 2], lzy[N << 2];
    bool inRange(int l, int r, int L, int R) { return l >= L && r <= R; }
    bool outRange(int l, int r, int L, int R) { return l > R || r < L; }
    void maketag(int u, int l, int r, lint x) {
        (w[u] += (r - l + 1) * x) %= p;
        (lzy[u] += x) %= p;
    }
    void pushup(int u) { w[u] = w[u * 2] + w[u * 2 + 1]; }
    void pushdown(int u, int l, int r) {
        int mid = l + r >> 1;
        maketag(u * 2, l, mid, lzy[u]);
        maketag(u * 2 + 1, mid + 1, r, lzy[u]);
        lzy[u] = 0;
    }
    void build(int u, int l, int r) {
        if (l == r) {
            w[u] = a[rnk[l]];
            return ;
        }
        int mid = l + r >> 1;
        build(u * 2, l, mid);
        build(u * 2 + 1, mid + 1, r);
        pushup(u);
    }
    lint query(int u, int l, int r, int L, int R) {
        if (outRange(l, r, L, R)) return 0;
        if (inRange(l, r, L, R)) return w[u];
        pushdown(u, l, r);
        int mid = l + r >> 1;
        return query(u * 2, l, mid, L, R) + query(u * 2 + 1, mid + 1, r, L, R);
    }
    void update(int u, int l, int r, int L, int R, lint x) {
        if (outRange(l, r, L, R)) return ;
        if (inRange(l, r, L, R)) { maketag(u, l, r, x); return ; }
        pushdown(u, l, r);
        int mid = l + r >> 1;
        update(u * 2, l, mid, L, R, x);
        update(u * 2 + 1, mid + 1, r, L, R, x);
        pushup(u);
    }
}

lint query(int x, int y) {
    lint res = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        // top[x] 更深,那么 x-top[x] 路径上不会有 LCA(x,y)
        (res += SegmentTree::query(1, 1, n, dfn[top[x]], dfn[x])) %= p;
        x = fa[top[x]];
    }
    return (res + SegmentTree::query(1, 1, n, std::min(dfn[x], dfn[y]), std::max(dfn[x], dfn[y]))) % p;
}

void update(int x, int y, int z) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        SegmentTree::update(1, 1, n, dfn[top[x]], dfn[x], z);
        x = fa[top[x]];
    }
    SegmentTree::update(1, 1, n, std::min(dfn[x], dfn[y]), std::max(dfn[x], dfn[y]), z);
}

int main() {
    std::cin.tie(0)->sync_with_stdio(0);
    std::cin >> n >> m >> r >> p;
    for (int i = 1; i <= n; ++i) std::cin >> a[i];
    for (int i = 1, x, y; i < n; ++i) {
        std::cin >> x >> y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    dfs1(r, 0);
    dfs2(r, 0);
    SegmentTree::build(1, 1, n);
    for (int op, x, y, z; m; --m) {
        std::cin >> op >> x;
        if (op == 1) {
            std::cin >> y >> z;
            update(x, y, z);
        } else if (op == 2) {
            std::cin >> y;
            std::cout << query(x, y) << '\n';
        } else if (op == 3) {
            std::cin >> z;
            SegmentTree::update(1, 1, n, dfn[x], dfn[x] + size[x] - 1, z);
        } else {
            std::cout << SegmentTree::query(1, 1, n, dfn[x], dfn[x] + size[x] - 1) % p << '\n';
        }
    }
    return 0;
}

by Wind_Leaves_ShaDow @ 2024-05-18 21:18:03

return query(u * 2, l, mid, L, R) + query(u * 2 + 1, mid + 1, r, L, R);

这不是没取模吗(?


by 残阳如血 @ 2024-05-18 21:23:15

@Wind_Leaves_ShaDow 哦,这里复制错了,取模了也不行a


by 残阳如血 @ 2024-05-18 21:24:45

@Wind_Leaves_ShaDow

比如这份代码就 WA 90。

奇怪的是, query() 函数(非线段树内部的)就没有这种问题。

#include <vector>
#include <iostream>
const int N = 1e5 + 10;
typedef long long lint;

std::vector<int> g[N];
int n, m, r, p, cnt, a[N];
int fa[N], dep[N], size[N], son[N], top[N], dfn[N], rnk[N];

void dfs1(int u, int f) {
    size[u] = 1, dep[u] = dep[fa[u] = f] + 1;
    for (auto v : g[u]) {
        if (v == f) continue;
        dfs1(v, u);
        size[u] += size[v];
        if (size[v] > size[son[u]]) son[u] = v;
    }
}

void dfs2(int u, int ftop) {
    dfn[u] = ++cnt, rnk[cnt] = u, top[u] = ftop;
    if (!son[u]) return ;
    dfs2(son[u], ftop); // 重儿子
    for (auto v : g[u]) {
        if (v == son[u] || v == fa[u]) continue;
        dfs2(v, v); // 轻儿子
    }
}

namespace SegmentTree {
    lint w[N << 2], lzy[N << 2];
    bool inRange(int l, int r, int L, int R) { return l >= L && r <= R; }
    bool outRange(int l, int r, int L, int R) { return l > R || r < L; }
    void maketag(int u, int l, int r, lint x) {
        (w[u] += (r - l + 1) * x) %= p;
        (lzy[u] += x) %= p;
    }
    void pushup(int u) { w[u] = w[u * 2] + w[u * 2 + 1]; }
    void pushdown(int u, int l, int r) {
        int mid = l + r >> 1;
        maketag(u * 2, l, mid, lzy[u]);
        maketag(u * 2 + 1, mid + 1, r, lzy[u]);
        lzy[u] = 0;
    }
    void build(int u, int l, int r) {
        if (l == r) {
            w[u] = a[rnk[l]];
            return ;
        }
        int mid = l + r >> 1;
        build(u * 2, l, mid);
        build(u * 2 + 1, mid + 1, r);
        pushup(u);
    }
    lint query(int u, int l, int r, int L, int R) {
        if (outRange(l, r, L, R)) return 0;
        if (inRange(l, r, L, R)) return w[u];
        pushdown(u, l, r);
        int mid = l + r >> 1;
        return (query(u * 2, l, mid, L, R) + query(u * 2 + 1, mid + 1, r, L, R)) % p;
    }
    void update(int u, int l, int r, int L, int R, lint x) {
        if (outRange(l, r, L, R)) return ;
        if (inRange(l, r, L, R)) { maketag(u, l, r, x); return ; }
        pushdown(u, l, r);
        int mid = l + r >> 1;
        update(u * 2, l, mid, L, R, x);
        update(u * 2 + 1, mid + 1, r, L, R, x);
        pushup(u);
    }
}

lint query(int x, int y) {
    lint res = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        // top[x] 更深,那么 x-top[x] 路径上不会有 LCA(x,y)
        (res += SegmentTree::query(1, 1, n, dfn[top[x]], dfn[x])) %= p;
        x = fa[top[x]];
    }
    return (res + SegmentTree::query(1, 1, n, std::min(dfn[x], dfn[y]), std::max(dfn[x], dfn[y]))) % p;
}

void update(int x, int y, int z) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        SegmentTree::update(1, 1, n, dfn[top[x]], dfn[x], z);
        x = fa[top[x]];
    }
    SegmentTree::update(1, 1, n, std::min(dfn[x], dfn[y]), std::max(dfn[x], dfn[y]), z);
}

int main() {
    std::cin.tie(0)->sync_with_stdio(0);
    std::cin >> n >> m >> r >> p;
    for (int i = 1; i <= n; ++i) std::cin >> a[i];
    for (int i = 1, x, y; i < n; ++i) {
        std::cin >> x >> y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    dfs1(r, 0);
    dfs2(r, 0);
    SegmentTree::build(1, 1, n);
    for (int op, x, y, z; m; --m) {
        std::cin >> op >> x;
        if (op == 1) {
            std::cin >> y >> z;
            update(x, y, z);
        } else if (op == 2) {
            std::cin >> y;
            std::cout << query(x, y) << '\n';
        } else if (op == 3) {
            std::cin >> z;
            SegmentTree::update(1, 1, n, dfn[x], dfn[x] + size[x] - 1, z);
        } else {
            std::cout << SegmentTree::query(1, 1, n, dfn[x], dfn[x] + size[x] - 1) << '\n';
        }
    }
    return 0;
}

by Wind_Leaves_ShaDow @ 2024-05-18 21:24:52

找到了。

if (inRange(l, r, L, R)) return w[u];

这里可能大于模数,原因好像是 a 数组本身就有大于模数的,我用你的代码把 a 数组模了一遍就过了。


by 残阳如血 @ 2024-05-19 08:23:26

@Wind_Leaves_ShaDow 哦哦哦,谢谢


|