wang_shuang @ 2024-11-28 09:27:26
#include <bits/stdc++.h>
// Capacity of every global array in this program.
constexpr int N = 1000000;
// Modulus applied to every reported answer; read in main.
int P;
// Larger of two ints (returns b when they are equal).
inline int max(int a, int b) { return (b < a) ? a : b; }
// Smaller of two ints (returns b when they are equal).
inline int min(int a, int b) { return (b > a) ? a : b; }
// Directed edge u -> v. main stores every undirected input edge in both directions,
// then sorts the array by u so that all edges leaving a node are contiguous.
struct Way { int u, v; };
Way way[N];

int hed[N]; // index in way[] of the first edge leaving a node (0 if it has none)
int num[N]; // initial value of each node, indexed by node id
int par[N]; // parent in the rooted tree (0 for the root)
int dep[N]; // depth, the root having depth 1
int siz[N]; // number of nodes in the subtree
int son[N]; // child with the largest subtree (0 for a leaf)
int top[N]; // first node of the heavy chain the node belongs to
int tse[N]; // DFS position assigned to each node
int ttr[N]; // inverse of tse: the node sitting at each DFS position
int ptr = 1; // next DFS position to hand out; positions start at 1
// First pass of the decomposition: records the depth d of p, the parent of each
// child, subtree sizes, and the heavy child son[p] (the child whose subtree is
// largest; the first one wins ties). Returns the size of p's subtree.
inline int init(int p, int d)
{
    dep[p] = d;
    siz[p] = 1;
    int heaviest = -1;
    // Outgoing edges of p are contiguous in way[], starting at hed[p].
    for (int e = hed[p]; way[e].u == p; ++e)
    {
        const int to = way[e].v;
        if (to != par[p])
        {
            par[to] = p;
            siz[p] += init(to, d + 1);
            if (siz[to] > heaviest)
            {
                heaviest = siz[to];
                son[p] = to;
            }
        }
    }
    return siz[p];
}
// Second pass: assigns DFS positions (tse/ttr) and chain tops. The heavy child is
// visited first and inherits the chain top t, so every heavy chain occupies
// consecutive positions; each light child starts a new chain headed by itself.
inline void dfs(int p, int t)
{
    top[p] = t;
    tse[p] = ptr;
    ttr[ptr] = p;
    ++ptr;
    if (son[p] == 0) return; // leaf
    dfs(son[p], t);
    for (int e = hed[p]; way[e].u == p; ++e)
    {
        const int to = way[e].v;
        if (to == par[p] || to == son[p]) continue;
        dfs(to, to);
    }
}
// segtree
int seg[N];
int tag[N];
#define L (p << 1)
#define R (p << 1 | 1)
#define M (l + r >> 1)
inline void build(int p, int l, int r)
{
if (l == r - 1) seg[p] = num[ttr[l]];
else
{
build(L, l, M);
build(R, M, r);
seg[p] = seg[L] + seg[R];
}
return;
}
inline void update(int p, int l, int r)
{
if (tag[p])
{
tag[p] %= P;
seg[p] += tag[p] * (r - l);
tag[L] += tag[p];
tag[R] += tag[p];
tag[p] = 0;
seg[p] %= P;
}
return;
}
inline int add(int p, int l, int r, int x, int y, int k)
{
update(p, l, r);
if (l >= y || r <= x) return 0;
if (l >= x && r <= y) return tag[p] += k;
seg[p] += k * (min(y, r) - max(x, l));
return add(L, l, M, x, y, k), add(R, M, r, x, y, k);
}
inline int find(int p, int l, int r, int x, int y)
{
update(p, l, r);
if (l >= y || r <= x) return 0;
if (l >= x && r <= y) return seg[p];
return (find(L, l, M, x, y) + find(R, M, r, x, y)) % P;
}
// segtree
// Exchanges the values of a and b.
inline void swap(int &a, int &b)
{
    const int old_a = a;
    a = b;
    b = old_a;
}
// Returns the sum, mod P, of the values of all nodes on the tree path between a
// and b (both inclusive). n is the node count; the segment tree spans [1, n + 1).
inline int findway(int a, int b, const int n)
{
    // long long so that adding a segment sum to ans cannot overflow int.
    long long ans = 0;
    while (top[a] != top[b])
    {
        // Lift the endpoint whose chain TOP is deeper. Comparing the endpoints' own
        // depths (dep[a] < dep[b]) can lift an endpoint above the LCA, so path
        // nodes get skipped or counted twice.
        if (dep[top[a]] < dep[top[b]]) swap(a, b);
        // The chain segment top[a]..a occupies consecutive DFS positions.
        ans = (ans + find(1, 1, n + 1, tse[top[a]], tse[a] + 1)) % P;
        a = par[top[a]];
    }
    // Same chain now: the shallower endpoint has the smaller DFS position.
    if (dep[a] > dep[b]) swap(a, b);
    return static_cast<int>((ans + find(1, 1, n + 1, tse[a], tse[b] + 1)) % P);
}
// Adds k to the value of every node on the tree path between a and b (both
// inclusive). n is the node count; the segment tree spans [1, n + 1).
inline void addway(int a, int b, int k, const int n)
{
    while (top[a] != top[b])
    {
        // Lift the endpoint whose chain TOP is deeper. Comparing the endpoints' own
        // depths (dep[a] < dep[b]) can lift an endpoint above the LCA, so the
        // update would hit nodes off the path or skip nodes on it.
        if (dep[top[a]] < dep[top[b]]) swap(a, b);
        // The chain segment top[a]..a occupies consecutive DFS positions.
        add(1, 1, n + 1, tse[top[a]], tse[a] + 1, k);
        a = par[top[a]];
    }
    // Same chain now: the shallower endpoint has the smaller DFS position.
    if (dep[a] > dep[b]) swap(a, b);
    add(1, 1, n + 1, tse[a], tse[b] + 1, k);
}
inline int findtre(int p, const int n)
{
return find(1, 1, n + 1, tse[p], tse[p] + siz[p]);
}
// Adds k to the value of every node in the subtree rooted at p, which occupies the
// contiguous DFS positions [tse[p], tse[p] + siz[p]).
inline void addtre(int p, int k, const int n)
{
    const int first = tse[p];
    const int last = first + siz[p];
    add(1, 1, n + 1, first, last, k);
}
// Sort predicate: orders directed edges by their source node only.
inline bool cmp(Way a, Way b) { return b.u > a.u; }
int main()
{
// freopen("a.in", "r", stdin);
// freopen("a.out", "w", stdout);
int n, m, r; scanf("%d %d %d %d", &n, &m, &r, &P); for (int i = 1; i <= n; i++) scanf("%d", num + i);
for (int i = 1; i < n; i++)
{
scanf("%d %d", &way[i << 1].u, &way[i << 1].v);
way[i << 1 | 1] = { way[i << 1].v, way[i << 1].u };
}
std::sort(way + 2, way + n * 2, cmp);
for (int i = 2; i < n * 2; i++) if (!hed[way[i].u]) hed[way[i].u] = i;
init(r, 1);
dfs(r, r);
// for (int i = 1; i <= n; i++) printf("%d ", par[i]); printf("\n");
// for (int i = 1; i <= n; i++) printf("%d ", dep[i]); printf("\n");
// for (int i = 1; i <= n; i++) printf("%d ", son[i]); printf("\n");
// for (int i = 1; i <= n; i++) printf("%d ", siz[i]); printf("\n");
// for (int i = 1; i <= n; i++) printf("%d ", top[i]); printf("\n");
// for (int i = 1; i <= n; i++) printf("%d ", tse[i]); printf("\n");
// for (int i = 1; i <= n; i++) printf("%d ", ttr[i]); printf("\n");
build(1, 1, n + 1);
while (m--)
{
int opt, a, b, c; scanf("%d %d", &opt, &a);
if (opt == 1)
{
scanf("%d %d", &b, &c);
addway(a, b, c, n);
}
if (opt == 2)
{
scanf("%d", &b);
printf("%d\n", findway(a, b, n));
}
if (opt == 3)
{
scanf("%d", &b);
addtre(a, b, n);
}
if (opt == 4)
{
printf("%d\n", findtre(a, n));
}
// for (int i = 1; i <= n; i++) printf("%d ", find(1, 1, n + 1, i, i + 1)); printf("\n");
}
return 0;
}
by chen_z @ 2024-11-28 10:10:27
@wang_shuang 你这么写线段树是历史遗留问题吗,好奇怪啊,从没见过谁线段树这么写的
by wang_shuang @ 2024-11-28 10:17:24
@chen_z 我学线段树是自己琢磨出来的,一开始就是用数组(可能是我不太喜欢用结构体)和左闭右开区间。