蒟蒻重链剖分板子,27pts,求条

P3384 【模板】重链剖分/树链剖分

wang_shuang @ 2024-11-28 09:27:26

#include <bits/stdc++.h>

const int N = 1e6;
int P;

inline int max(int a, int b) { return a > b ? a: b; }
inline int min(int a, int b) { return a < b ? a: b; }

struct Way { int u, v; } way[N];
int hed[N];

int num[N];
int par[N];
int dep[N];
int siz[N];
int son[N];
int top[N];
int tse[N], ptr = 1;
int ttr[N];

inline int init(int p, int d)
{
    dep[p] = d; siz[p] = 1; int ma = -1;
    for (int i = hed[p]; way[i].u == p; i++)
    {
        if (way[i].v == par[p]) continue;
        par[way[i].v] = p;
        siz[p] += init(way[i].v, d + 1);
        if (siz[way[i].v] > ma) son[p] = way[i].v, ma = siz[way[i].v];
    }
    return siz[p];
}

inline void dfs(int p, int t)
{
    tse[p] = ptr;
    ttr[ptr++] = p; //printf("%d %d %d\n", p, ptr - 1, tse[p], ttr[ptr - 1]);
    top[p] = t;
    if (!son[p]) return;
    dfs(son[p], t);
    for (int i = hed[p]; way[i].u == p; i++)
    {
        if (way[i].v == par[p]) continue;
        if (way[i].v == son[p]) continue;
        dfs(way[i].v, way[i].v);
    }
    return;
}

// segtree

int seg[N];
int tag[N];

#define L (p << 1)
#define R (p << 1 | 1)
#define M (l + r >> 1)

inline void build(int p, int l, int r)
{
    if (l == r - 1) seg[p] = num[ttr[l]];
    else
    {
        build(L, l, M);
        build(R, M, r);
        seg[p] = seg[L] + seg[R];
    }
    return;
}

inline void update(int p, int l, int r)
{
    if (tag[p])
    {
        tag[p] %= P;
        seg[p] += tag[p] * (r - l);
        tag[L] += tag[p];
        tag[R] += tag[p];
        tag[p] = 0;
        seg[p] %= P;
    }
    return;
}

inline int add(int p, int l, int r, int x, int y, int k)
{
    update(p, l, r);
    if (l >= y || r <= x) return 0;
    if (l >= x && r <= y) return tag[p] += k;
    seg[p] += k * (min(y, r) - max(x, l));
    return add(L, l, M, x, y, k), add(R, M, r, x, y, k);
}

inline int find(int p, int l, int r, int x, int y)
{
    update(p, l, r);
    if (l >= y || r <= x) return 0;
    if (l >= x && r <= y) return seg[p];
    return (find(L, l, M, x, y) + find(R, M, r, x, y)) % P;
}

// segtree

inline void swap(int&a, int&b) { int c = a; a = b; b = c; return; }

inline int findway(int a, int b, const int n)
{
    int ans = 0;
    while (top[a] - top[b])
    {
        if (dep[a] < dep[b]) swap(a, b);
        ans = (ans + find(1, 1, n + 1, tse[top[a]], tse[a] + 1)) % P; a = par[top[a]];
//      printf("%d ", ans);
    }
    if (dep[a] > dep[b]) swap(a, b);
    return (ans + find(1, 1, n + 1, tse[a], tse[b] + 1)) % P;
}

inline void addway(int a, int b, int k, const int n)
{
    while (top[a] - top[b])
    {
        if (dep[a] < dep[b]) swap(a, b);
        add(1, 1, n + 1, tse[top[a]], tse[a] + 1, k); a = par[top[a]];
    }
    if (dep[a] > dep[b]) swap(a, b);
    add(1, 1, n + 1, tse[a], tse[b] + 1, k);
    return;
}

inline int findtre(int p, const int n)
{
    return find(1, 1, n + 1, tse[p], tse[p] + siz[p]);
}

inline void addtre(int p, int k, const int n)
{
    add(1, 1, n + 1, tse[p], tse[p] + siz[p], k);
//  printf("%d %d %d %d %d\n", p, tse[p], tse[p] + siz[p], k, find(1, 1, n + 1, tse[p], tse[p] + siz[p]));
    return;
}

inline bool cmp(Way a, Way b) { return a.u < b.u; }

int main()
{
//  freopen("a.in", "r", stdin);
//  freopen("a.out", "w", stdout);

    int n, m, r; scanf("%d %d %d %d", &n, &m, &r, &P); for (int i = 1; i <= n; i++) scanf("%d", num + i);

    for (int i = 1; i < n; i++)
    {
        scanf("%d %d", &way[i << 1].u, &way[i << 1].v);
        way[i << 1 | 1] = { way[i << 1].v, way[i << 1].u };
    }
    std::sort(way + 2, way + n * 2, cmp);
    for (int i = 2; i < n * 2; i++) if (!hed[way[i].u]) hed[way[i].u] = i;

    init(r, 1);
    dfs(r, r);
//  for (int i = 1; i <= n; i++) printf("%d ", par[i]); printf("\n");
//  for (int i = 1; i <= n; i++) printf("%d ", dep[i]); printf("\n");
//  for (int i = 1; i <= n; i++) printf("%d ", son[i]); printf("\n");
//  for (int i = 1; i <= n; i++) printf("%d ", siz[i]); printf("\n");
//  for (int i = 1; i <= n; i++) printf("%d ", top[i]); printf("\n");
//  for (int i = 1; i <= n; i++) printf("%d ", tse[i]); printf("\n");
//  for (int i = 1; i <= n; i++) printf("%d ", ttr[i]); printf("\n");
    build(1, 1, n + 1);
    while (m--)
    {
        int opt, a, b, c; scanf("%d %d", &opt, &a);
        if (opt == 1)
        {
            scanf("%d %d", &b, &c);
            addway(a, b, c, n);
        }
        if (opt == 2)
        {
            scanf("%d", &b);
            printf("%d\n", findway(a, b, n));
        }
        if (opt == 3)
        {
            scanf("%d", &b);
            addtre(a, b, n);
        }
        if (opt == 4)
        {
            printf("%d\n", findtre(a, n));
        }
//      for (int i = 1; i <= n; i++) printf("%d ", find(1, 1, n + 1, i, i + 1)); printf("\n");
    }

    return 0;
}

by chen_z @ 2024-11-28 10:10:27

@wang_shuang 你这么写线段树是历史遗留问题吗,好奇怪啊,从没见过谁线段树这么写的


by wang_shuang @ 2024-11-28 10:17:24

@chen_z 我学线段树是自己琢磨出来的,一开始就是用数组(可能是我不太喜欢用结构体)和左闭右开区间。


|