xinhuo2005 @ 2024-07-13 15:49:54
#include<bits/stdc++.h>
#define ll long long
using namespace std;
// Capacity for per-node arrays (node ids are 1-based); 200000 + 10 == 2e5 + 10.
const int maxn = 200000 + 10;
// e[u]: adjacency list of node u (edges are inserted in both directions).
vector<int> e[maxn];
// a[i]: weight of the node whose DFS index is i (filled by dfs2, read by bulid).
int a[maxn];
// w[u]: weight of node u as read from the input.
int w[maxn];
// Lazy segment tree over DFS order: seg[p] is the range sum (mod `mod`),
// tag[p] is the addition still pending for p's children.
ll seg[maxn << 2];
ll tag[maxn << 2];
// Heavy-light decomposition arrays, filled by dfs1 and dfs2.
int fa[maxn];   // parent of each node
int siz[maxn];  // subtree size
int son[maxn];  // heavy child, 0 when the node has no children
int dep[maxn];  // depth below the DFS start node
int dfn[maxn];  // DFS index of each node
int top[maxn];  // first node of the heavy chain containing the node
int rk[maxn];   // rk[dfn[u]] == u
// Modulus applied to all stored sums and tags.
int mod;
// Reads the next integer from stdin using getchar().
// Any non-digit characters before the number are skipped; if a '-' appears
// among them the result is negated. Returns 0 when EOF is reached before
// any digit, instead of looping forever on EOF.
inline int read()
{
    int x = 0, f = 1;
    // Must be int, not char: a char cannot reliably hold EOF.
    int c = getchar();
    while (c < '0' || c > '9')
    {
        if (c == EOF) return 0;
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}
// Inserts the undirected edge u--v: v is appended to u's adjacency list,
// then u is appended to v's.
void add(int u, int v)
{
    auto link = [](int from, int to) { e[from].push_back(to); };
    link(u, v);
    link(v, u);
}
// Left child index of segment-tree node p (heap-style numbering).
inline int ls(int p) { return 2 * p; }
// Right child index of segment-tree node p (heap-style numbering).
inline int rs(int p) { return 2 * p + 1; }
// Recomputes node p's sum from its two children, reduced mod `mod`.
void push_up(int p)
{
    const ll left = seg[ls(p)];
    const ll right = seg[rs(p)];
    seg[p] = (left + right) % mod;
}
void push_down(int l, int r, int p)
{
int mid = (l + r) >> 1;
seg[ls(p)] = (seg[ls(p)] + tag[p] * (mid - l + 1ll)) % mod;
seg[rs(p)] = (seg[rs(p)] + tag[p] * (r - mid)) % mod;
tag[ls(p)] = (tag[ls(p)] + tag[p]) % mod;
tag[rs(p)] = (tag[rs(p)] + tag[p]) % mod;
tag[p] = 0;
}
// Builds segment-tree node p over positions [l, r] from a[l..r].
// A leaf stores a[l] % mod; an internal node stores its children's sum.
void bulid(int l, int r, int p)
{
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        bulid(l, mid, ls(p));
        bulid(mid + 1, r, rs(p));
        push_up(p);
        return;
    }
    seg[p] = a[l] % mod;
}
// Adds x to every position in [s, t], starting at segment-tree node p,
// which covers [l, r]. x is first reduced into [0, mod).
// A node fully inside [s, t] is updated in place and its lazy tag grows;
// otherwise pending tags are pushed down and the update recurses.
void update_add(int s, int t, int l, int r, int x, int p)
{
    x %= mod;
    if (x < 0) x += mod;  // read() can return negatives; keep values in [0, mod)
    if (s <= l && r <= t)
    {
        // Widen before multiplying. mod is an arbitrary int from the input, so
        // x may be close to INT_MAX, and an int product (r - l + 1) * x would
        // overflow for any non-trivial range length.
        seg[p] = (seg[p] + 1ll * (r - l + 1) * x) % mod;
        tag[p] = (tag[p] + x) % mod;
        return;
    }
    push_down(l, r, p);
    const int mid = (l + r) >> 1;
    if (s <= mid) update_add(s, t, l, mid, x, ls(p));
    if (mid < t) update_add(s, t, mid + 1, r, x, rs(p));
    push_up(p);
}
ll getSum(int s, int t, int l, int r, int p)
{
ll res = 0ll;
if (s <= l && r <= t) return seg[p] % mod;
push_down(l, r, p);
int mid = (l + r) >> 1;
if (s <= mid) res = (res + getSum(s, t, l, mid, ls(p))) % mod;
if (mid < t) res = (res + getSum(s, t, mid + 1, r, rs(p))) % mod;
return res;
}
// First HLD pass, rooted at u, whose parent is fno.
// Records fa[u] and siz[u], sets dep[] for each child before visiting it, and
// picks son[u] as the child with the largest subtree. On ties the first such
// child wins; son[u] stays 0 when u has no children.
void dfs1(int u, int fno)
{
    fa[u] = fno;
    siz[u] = 1;
    son[u] = 0;
    for (int v : e[u])
    {
        if (v == fno) continue;
        dep[v] = dep[u] + 1;
        dfs1(v, u);
        siz[u] += siz[v];
        if (son[u] == 0 || siz[v] > siz[son[u]]) son[u] = v;
    }
}
int n;     // number of nodes
int m;     // number of operations
int root;  // node the decomposition is rooted at
int cnt;   // running DFS index counter advanced by dfs2
// Second HLD pass: gives u the next DFS index and records t as the top of
// its chain. Copies w[u] into a[] at that index. Visits the heavy child first
// (same chain, so chains get contiguous indices), then starts a new chain at
// every light child.
void dfs2(int u, int t)
{
    dfn[u] = ++cnt;
    top[u] = t;
    rk[dfn[u]] = u;
    a[dfn[u]] = w[u];
    if (son[u] != 0)
    {
        dfs2(son[u], t);
        for (int v : e[u])
            if (v != fa[u] && v != son[u])
                dfs2(v, v);
    }
}
// Adds z to every node on the tree path between x and y.
// While x and y lie on different heavy chains, take the endpoint whose chain
// top is deeper, cover its chain segment [dfn[top], dfn[node]], and move it
// to the parent of that top. Once both share a chain, cover the segment
// between them.
void opt1(int x, int y, int z)
{
    while (top[x] != top[y])
    {
        // Compare the depths of the chain TOPS, not of x and y. Lifting the
        // endpoint whose top is shallower can cover nodes above the LCA,
        // which adds z off the path and/or to some nodes twice.
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        update_add(dfn[top[x]], dfn[x], 1, n, z, 1);
        x = fa[top[x]];
    }
    if (dep[x] < dep[y]) swap(x, y);
    update_add(dfn[y], dfn[x], 1, n, z, 1);
}
// Returns the sum (mod `mod`) of node values on the tree path between x and y.
// Same traversal as the path update: always lift the endpoint whose heavy
// chain top is deeper, then add the final same-chain segment.
ll opt2(int x, int y)
{
    ll res = 0ll;
    while (top[x] != top[y])
    {
        // Chain-top depths decide which endpoint to lift. Comparing dep[x]
        // and dep[y] can jump past the LCA and count nodes twice or off-path.
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        res = (res + getSum(dfn[top[x]], dfn[x], 1, n, 1)) % mod;
        x = fa[top[x]];
    }
    if (dep[x] < dep[y]) swap(x, y);
    res = (res + getSum(dfn[y], dfn[x], 1, n, 1)) % mod;
    return res;
}
// Adds z to every node in the subtree of x. In DFS order that subtree is the
// contiguous index range [dfn[x], dfn[x] + siz[x] - 1].
void opt3(int x, int z)
{
    const int first = dfn[x];
    const int last = first + siz[x] - 1;
    update_add(first, last, 1, n, z, 1);
}
// Returns the sum (mod `mod`) of all node values in the subtree of x, i.e.
// over DFS indices [dfn[x], dfn[x] + siz[x] - 1].
ll opt4(int x)
{
    const int first = dfn[x];
    const int last = first + siz[x] - 1;
    return getSum(first, last, 1, n, 1);
}
// Reads n, m, root, mod, the n node weights and the n - 1 edges, builds the
// decomposition and the segment tree, then processes m operations:
//   1 x y z : add z on the path x-y
//   2 x y   : print the path sum
//   3 x z   : add z to the subtree of x
//   other x : print the subtree sum of x
int main()
{
    n = read();
    m = read();
    root = read();
    mod = read();
    for (int i = 1; i <= n; ++i) w[i] = read();
    for (int edge = 1; edge < n; ++edge)
    {
        const int u = read();
        const int v = read();
        add(u, v);
    }
    dfs1(root, root);
    dfs2(root, root);
    bulid(1, n, 1);
    while (m--)
    {
        const int opt = read();
        switch (opt)
        {
        case 1:
        {
            const int x = read();
            const int y = read();
            const int z = read();
            opt1(x, y, z);
            break;
        }
        case 2:
        {
            const int x = read();
            const int y = read();
            printf("%lld\n", opt2(x, y));
            break;
        }
        case 3:
        {
            const int x = read();
            const int z = read();
            opt3(x, z);
            break;
        }
        default:
        {
            const int x = read();
            printf("%lld\n", opt4(x));
            break;
        }
        }
    }
    return 0;
}
by wby_1234 @ 2024-07-13 16:22:43
将“ seg[p] = (seg[ls(p)] + seg[rs(p)]) % mod;” 改为“ seg[p] = (seg[ls(p)] + seg[rs(p)]) % mod+1;”即可。
by xinhuo2005 @ 2024-07-13 16:41:56
@wby_1234 大佬不对,改了还是没过
by wby_1234 @ 2024-07-13 17:50:15
我过了
by LWT223355 @ 2024-08-01 09:41:30
@xinhuo2005 教练~ stO Orz