Pentiment @ 2023-12-19 14:47:21
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 100005;
int n, q, r, mod, cnt;
vector<int> G[MAXN];
struct node {
vector<int> ch;
int fa, w, dep, siz, son, top, dfn, low;
} a[MAXN];
void dfs1(int u, int fa, int dep) {
a[u].fa = fa, a[u].dep = dep, a[u].siz = 1;
for (auto v : G[u]) {
if (v != fa) {
a[u].ch.push_back(v);
dfs1(v, u, dep + 1);
a[u].siz += a[v].siz;
}
}
}
void dfs2(int u, int rt) {
a[u].dfn = a[u].low = ++cnt, a[u].top = rt;
for (auto v : a[u].ch) {
if (a[v].siz > a[a[u].son].siz) a[u].son = v;
}
if (!a[u].son) return;
dfs2(a[u].son, rt);
for (auto v : a[u].ch) {
a[u].low = max(a[u].low, a[v].low);
if (v == a[u].son) continue;
dfs2(v, v);
}
}
int c[MAXN * 4], m[MAXN * 4];
inline void pushdown(int l, int r, int mid, int p) {
if (m[p]) {
c[p * 2] = (c[p * 2] + (long long)(mid - l + 1) * m[p]) % mod, m[p * 2] = ((long long)m[p * 2] + m[p]) % mod;
c[p * 2 + 1] = (c[p * 2 + 1] + (long long)(r - mid) * m[p]) % mod, m[p * 2 + 1] = ((long long)m[p * 2 + 1] + m[p]) % mod;
m[p] = 0;
}
}
void update(int s, int t, int v, int l, int r, int p) {
if (s <= l && r <= t) {
c[p] = (c[p] + (long long)(r - l + 1) * v) % mod, m[p] = ((long long)m[p] + v) % mod;
return;
}
int mid = (l + r) / 2;
if (l != r) pushdown(l, r, mid, p);
if (s <= mid) update(s, t, v, l, mid, p * 2);
if (t > mid) update(s, t, v, mid + 1, r, p * 2 + 1);
c[p] = ((long long)c[p * 2] + c[p * 2 + 1]) % mod;
}
int getsum(int s, int t, int l, int r, int p) {
if (s <= l && r <= t) return c[p];
int mid = (l + r) / 2, ans = 0;
pushdown(l, r, mid, p);
if (s <= mid) ans = ((long long)ans + getsum(s, t, l, mid, p * 2)) % mod;
if (t > mid) ans = ((long long)ans + getsum(s, t, mid + 1, r, p * 2 + 1)) % mod;
return ans;
}
void upd(int u, int v, int w) {
while (a[u].top != a[v].top) {
if (a[a[u].top].dep > a[a[v].top].dep) {
update(a[a[u].top].dfn, a[u].dfn, w, 1, n, 1);
u = a[a[u].top].fa;
} else {
update(a[a[v].top].dfn, a[v].dfn, w, 1, n, 1);
v = a[a[v].top].fa;
}
}
if (a[u].dep > a[v].dep) swap(u, v);
update(a[u].dfn, a[v].dfn, w, 1, n, 1);
}
int qsum(int u, int v) {
int ans = 0;
while (a[u].top != a[v].top) {
if (a[a[u].top].dep > a[a[v].top].dep) {
ans = ((long long)ans + getsum(a[a[u].top].dfn, a[u].dfn, 1, n, 1)) % mod;
u = a[a[u].top].fa;
} else {
ans = ((long long)ans + getsum(a[a[v].top].dfn, a[v].dfn, 1, n, 1)) % mod;
v = a[a[v].top].fa;
}
}
if (a[u].dep > a[v].dep) swap(u, v);
ans = ((long long)ans + getsum(a[u].dfn, a[v].dfn, 1, n, 1)) % mod;
return ans;
}
inline void updtree(int u, int w) {
update(a[u].dfn, a[u].low, w, 1, n, 1);
}
inline int qsumtree(int u) {
return getsum(a[u].dfn, a[u].low, 1, n, 1);
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
cin >> n >> q >> r >> mod;
for (int i = 1; i <= n; i++) cin >> a[i].w;
for (int i = 1, u, v; i < n; i++) {
cin >> u >> v;
G[u].push_back(v), G[v].push_back(u);
}
dfs1(r, r, 1);
dfs2(r, r);
for (int i = 1; i <= n; i++) update(a[i].dfn, a[i].dfn, a[i].w, 1, n, 1);
int op, x, y, z;
while (q--) {
cin >> op >> x;
switch (op) {
case 1:
cin >> y >> z;
upd(x, y, z);
break;
case 2:
cin >> y;
cout << qsum(x, y) << endl;
break;
case 3:
cin >> z;
updtree(x, z);
break;
case 4:
cout << qsumtree(x) << endl;
break;
}
}
}
by __Chx__ @ 2023-12-22 08:48:51
@Run_Time_Error
28行:
a[u].low = max(a[u].low, a[v].low);
应放在 dfs(v,v)
之后,否则a[v].low
返回值为 0。
26行:便利重儿子后也应转移 low
改完这些小问题就能顺利AC了。
by Pentiment @ 2023-12-22 18:13:55
@Chx thx