MINO1 @ 2024-01-06 19:46:44
#include<iostream>
#include<vector>
using namespace std;
// ---- Global problem state shared by every routine in this file ----
int n, m, root, P;              // vertex count, operation count, root vertex, modulus
long long point_val[100005];    // initial vertex weights (1-indexed); solve() reduces them mod P
std::vector<int> tree[100005];  // undirected adjacency lists
int fa[100005];                 // fa[u]  : parent of u (0 for the root)
int dep[100005];                // dep[u] : depth of u; the root has depth 1 because dep[0] == 0
int hev[100005];                // hev[u] : heavy child of u (child with the largest subtree), 0 for a leaf
int sz[100005];                 // sz[u]  : number of vertices in u's subtree

// First heavy-light pass: fills fa, dep, sz and hev for the subtree rooted at u.
// @param u   current vertex
// @param pre parent of u (0 when u is the root)
void dfs1(int u, int pre) {
    fa[u] = pre;
    dep[u] = dep[pre] + 1;
    sz[u] = 1;
    int msz = 0;  // largest child-subtree size seen so far
    int hv = 0;   // child that achieves msz
    for (int v : tree[u]) {
        if (v == pre) continue;
        dfs1(v, u);
        sz[u] += sz[v];
        // msz must track the running maximum. If it is never updated, every child
        // beats 0 and the *last* child becomes the "heavy" one. Answers stay correct,
        // but a path can then cross O(n) light edges, so queries become very slow.
        if (sz[v] > msz) {
            msz = sz[v];
            hv = v;
        }
    }
    hev[u] = hv;
}
int dfn[100005], rdfn[100005];
int dfn_id = 0; int top[100005];
void dfs2(int u, int topf) {
dfn[++dfn_id] = u;
rdfn[u] = dfn_id;
top[u] = topf;
if (!hev[u]) return;
dfs2(hev[u], topf);
for (auto v : tree[u]) {
if (v == fa[u] || v == hev[u]) continue;
dfs2(v, v);
}
}
// ---- Segment tree ("xds") over heavy-light positions: range add, range sum, mod P ----
// xds_val[p] : sum of node p's range, mod P
// lan[p]     : lazy addition pending for every element of node p's range,
//              not yet pushed down to p's children
// Node 1 covers [1, n]; node p's children are 2p and 2p + 1.
long long xds_val[400005];
long long lan[400005];

// Builds node p, which covers positions [s, t], from the initial vertex weights.
// Position k holds the weight of vertex dfn[k].
void create_xds(int p, int s, int t) {
    if (s != t) {
        const int mid = (s + t) >> 1;
        create_xds(p * 2, s, mid);
        create_xds(p * 2 + 1, mid + 1, t);
        xds_val[p] = (xds_val[p * 2] + xds_val[p * 2 + 1]) % P;
        return;
    }
    xds_val[p] = point_val[dfn[s]] % P;
}
void add2(int p, int l, int r, int s, int t, int val) {
if (l <= s && r >= t) {
lan[p] = (lan[p] + val) % P;
xds_val[p] = (xds_val[p] + (val * (static_cast<long long>(t) - s + 1)) % P) % P;
return;
}
int m = (s + t) >> 1;
if (lan[p] && s != t) {
lan[p * 2] = (lan[p * 2] + lan[p]) % P, lan[p * 2 + 1] = (lan[p * 2 + 1] + lan[p]) % P;
xds_val[p * 2] = (xds_val[p * 2] + (lan[p] * (static_cast<long long>(m) - s + 1)) % P) % P, xds_val[p * 2 + 1] = (xds_val[p * 2 + 1] + (lan[p] * (static_cast<long long>(t) - m)) % P) % P;
lan[p] = 0;
}
if (l <= m) add2(p * 2, l, r, s, m, val);
if (r > m) add2(p * 2 + 1, l, r, m + 1, t, val);
xds_val[p] = (xds_val[p * 2] + xds_val[p * 2 + 1]) % P;
}
// Returns the sum of positions [l, r] mod P.
// @param p   segment-tree node index; the node covers positions [s, t]
// @param l,r query range (inclusive)
// @param s,t range covered by node p
// Pending lazy additions on the path are pushed down as the query descends.
long long sum2(int p, int l, int r, int s, int t) {
    if (s >= l && t <= r) return xds_val[p] % P;
    const int mid = (s + t) >> 1;
    const int lc = p * 2;
    const int rc = p * 2 + 1;
    const long long tag = lan[p];
    if (tag != 0 && s != t) {
        lan[lc] = (lan[lc] + tag) % P;
        lan[rc] = (lan[rc] + tag) % P;
        xds_val[lc] = (xds_val[lc] + tag * (static_cast<long long>(mid) - s + 1) % P) % P;
        xds_val[rc] = (xds_val[rc] + tag * (static_cast<long long>(t) - mid) % P) % P;
        lan[p] = 0;
    }
    long long res = 0;
    if (l <= mid) res = (res + sum2(lc, l, r, s, mid)) % P;
    if (r > mid) res = (res + sum2(rc, l, r, mid + 1, t)) % P;
    xds_val[p] = (xds_val[lc] + xds_val[rc]) % P;
    return res % P;
}
// Adds z to every vertex on the tree path between x and y.
// While x and y lie on different heavy chains, take the endpoint whose chain head
// is deeper, update the contiguous positions from that head to the endpoint, and
// jump to the head's parent. Once both share a chain, update the remaining segment.
void add1(int x, int y, int z) {
    for (; top[x] != top[y]; x = fa[top[x]]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        add2(1, rdfn[top[x]], rdfn[x], 1, n, z);
    }
    if (dep[x] < dep[y]) std::swap(x, y);
    add2(1, rdfn[y], rdfn[x], 1, n, z);
}
long long sum1(int x, int y) {
long long sum = 0;
while (top[x] != top[y]) {
if (dep[top[y]] > dep[top[x]]) swap(x, y);
sum = (sum + sum2(1, rdfn[top[x]], rdfn[x], 1, n)) % P;
x = fa[top[x]];
}
if (dep[y] > dep[x]) swap(x, y);
sum = (sum + sum2(1, rdfn[y], rdfn[x], 1, n)) % P;
return sum % P;
}
// Reads the input, builds the heavy-light decomposition and the segment tree,
// then processes m operations:
//   1 x y z : add z to every vertex on the path x - y
//   2 x y   : print the sum of the path x - y, mod P
//   3 x z   : add z to every vertex in x's subtree
//   4 x     : print the sum of x's subtree, mod P
// Answers end with '\n' rather than std::endl, so stdout is not flushed after
// every query. The output bytes are identical and are flushed at normal exit.
void solve() {
    std::cin >> n >> m >> root >> P;
    for (int i = 1; i <= n; i++) {
        std::cin >> point_val[i];
        point_val[i] %= P;
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        std::cin >> u >> v;
        tree[u].push_back(v);
        tree[v].push_back(u);
    }
    dfs1(root, 0);
    dfs2(root, root);
    create_xds(1, 1, n);
    while (m--) {
        int q;
        std::cin >> q;
        // Subtree operations use the fact that x's subtree is the contiguous
        // position range [rdfn[x], rdfn[x] + sz[x] - 1] in dfs2's numbering.
        switch (q) {
        case 1: {
            int x, y, z;
            std::cin >> x >> y >> z;
            add1(x, y, z % P);
            break;
        }
        case 2: {
            int x, y;
            std::cin >> x >> y;
            std::cout << sum1(x, y) % P << '\n';
            break;
        }
        case 3: {
            int x, z;
            std::cin >> x >> z;
            add2(1, rdfn[x], rdfn[x] + sz[x] - 1, 1, n, z % P);
            break;
        }
        case 4: {
            int x;
            std::cin >> x;
            std::cout << sum2(1, rdfn[x], rdfn[x] + sz[x] - 1, 1, n) % P << '\n';
            break;
        }
        default:
            break;  // unknown operation code: ignored, as before
        }
    }
}
// Entry point: turns off C stdio synchronization and untie the streams so bulk
// stream I/O is fast, then runs the solver.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    solve();
    return 0;
}
by sunkuangzheng @ 2024-01-06 19:48:21
@MINO1 找重儿子的时候没有更新 msz。
这就相当于直接取每个点的最后一个儿子当重儿子。
by MINO1 @ 2024-01-06 19:50:49
@sunkuangzheng 我靠,这里出问题了,感谢
by MINO1 @ 2024-01-06 19:52:25
@sunkuangzheng ac这么多,我还以为是死循环了...