xiaozhuo @ 2023-09-19 12:40:43
#include<bits/stdc++.h>
using namespace std;
#define ll long long
// ---- Problem input ----
int n;                  // number of vertices
int root;               // root vertex of the tree
int m;                  // number of operations
int p;                  // modulus applied to all sums
// ---- Heavy-light decomposition state, indexed by vertex (1-based) ----
int top[100010];        // top[u]: head vertex of the heavy chain containing u (set by dfs2)
int dep[100010];        // dep[u]: depth of u; the root stays 0 (set by dfs1)
// NOTE(review): together with `using namespace std;`, compiling as C++17 or later makes this
// name clash with std::size ("reference to 'size' is ambiguous"); renaming it (e.g. `sz`)
// everywhere in the file would avoid that.
int size[100010];       // size[u]: number of vertices in u's subtree (set by dfs1)
int son[100010];        // son[u]: child of u with the largest subtree, 0 if none (set by dfs1)
int fa[100010];         // fa[u]: parent of u (set by dfs1)
int rev[100010];        // rev[t]: vertex whose DFS position is t (inverse of id, set by dfs2)
int id[100010];         // id[u]: DFS position of u in the segment tree (set by dfs2)
int w[100010];          // w[u]: initial weight of vertex u
// ---- Adjacency list / counters ----
int head[100010];       // head[u]: index in e[] of u's first edge, 0 = no edges
int cnt;                // number of edges stored in e[] so far
int tot;                // running DFS position counter used by dfs2
// One segment-tree node. The array has 400010 slots, roughly 4x the vertex bound,
// which is the usual capacity for a 1-indexed recursive segment tree.
struct Tree
{
    ll tag;   // pending lazy addition not yet pushed to the children
    ll sum;   // sum of the node's range, stored reduced mod p
} tr[400010];
// One directed edge of the adjacency list. Room for 200010 entries, because every
// undirected tree edge is stored twice.
struct Node
{
    int to;     // target vertex of this edge
    int next;   // index in e[] of the next edge with the same source; 0 ends the list
} e[200010];
// Append the directed edge u -> v to u's adjacency list. The new edge becomes the list head,
// so a vertex's edges are walked in reverse insertion order.
void add(int u, int v)
{
    ++cnt;                  // edge indices start at 1; 0 is the list terminator
    e[cnt].to = v;
    e[cnt].next = head[u];
    head[u] = cnt;
}
// First heavy-light pass over the subtree rooted at u, whose parent is f.
// For every vertex it fills size[], and for every child it fills fa[] and dep[].
// It also sets son[u] to the child with the largest subtree (first one wins on ties).
void dfs1(int u, int f)
{
    size[u] = 1;
    for(int edge = head[u]; edge != 0; edge = e[edge].next)
    {
        const int child = e[edge].to;
        if(child == f) continue;    // do not walk back to the parent
        fa[child] = u;
        dep[child] = dep[u] + 1;
        dfs1(child, u);
        size[u] += size[child];
        // son[u] starts at 0, and size[0] is never written (assuming vertex 0 is unused),
        // so the first child always replaces it.
        if(size[son[u]] < size[child]) son[u] = child;
    }
}
// Second heavy-light pass. It records t as the chain head of u and gives u the next DFS
// position (id[u], with the inverse in rev[]). The heavy child is visited first so that
// each heavy chain occupies consecutive positions.
void dfs2(int u, int t)
{
    top[u] = t;
    ++tot;
    id[u] = tot;
    rev[tot] = u;
    if(son[u] == 0) return;           // no heavy child means u has no children at all
    dfs2(son[u], t);                  // the heavy child continues the current chain
    for(int edge = head[u]; edge; edge = e[edge].next)
    {
        const int child = e[edge].to;
        if(child != fa[u] && child != son[u])
            dfs2(child, child);       // each light child starts a new chain headed by itself
    }
}
void pushup(int rt)
{
tr[rt].sum = (tr[rt * 2].sum + tr[rt * 2 + 1].sum % p) % p;
}
void pushdown(int rt, int len)
{
if(tr[rt].tag)
{
tr[rt * 2].sum = (tr[rt * 2].sum + tr[rt].tag % p * (len - (len >> 1))) % p;
tr[rt * 2 + 1].sum = (tr[rt * 2 + 1].sum + tr[rt].tag % p * (len >> 1)) % p;
tr[rt * 2].tag += tr[rt].tag;
tr[rt * 2 + 1].tag += tr[rt].tag;
tr[rt].tag = 0;
}
}
// Build node rt covering DFS positions [l, r].
// A leaf stores the weight of the vertex placed at position l (w[rev[l]]), reduced mod p.
// An inner node is the sum of its two halves.
void build(int l, int r, int rt)
{
    if(l < r)
    {
        const int mid = (l + r) >> 1;
        build(l, mid, rt * 2);
        build(mid + 1, r, rt * 2 + 1);
        pushup(rt);
    }
    else
    {
        tr[rt].sum = w[rev[l]] % p;
    }
}
// Return the sum (mod p) of positions in [L, R] that fall inside node rt's range [l, r].
// Before descending into a partially covered node, its pending tag is pushed down.
ll query(int L, int R, int l, int r, int rt)
{
    if(l >= L && r <= R) return tr[rt].sum % p;   // node range lies entirely inside [L, R]
    pushdown(rt, r - l + 1);
    const int mid = (l + r) >> 1;
    ll total = 0;
    if(mid < R)
    {
        const ll right_part = query(L, R, mid + 1, r, rt * 2 + 1) % p;
        total = (total + right_part) % p;
    }
    if(mid >= L)
    {
        const ll left_part = query(L, R, l, mid, rt * 2) % p;
        total = (total + left_part) % p;
    }
    return total % p;
}
void update(int L, int R, int l, int r, int rt, int k)
{
if(L <= l && R >= r)
{
tr[rt].sum = (tr[rt].sum + k * (r - l + 1)) % p;
tr[rt].tag += k;
return;
}
int mid = (l + r) >> 1;
if(R > mid) update(L, R, mid + 1, r, rt * 2 + 1, k);
if(L <= mid) update(L, R, l, mid, rt * 2, k);
pushup(rt);
}
// Add k (reduced mod p) to every vertex on the tree path between x and y.
// While x and y sit on different heavy chains, take the endpoint whose chain head is deeper,
// update that chain segment, and jump to the parent of the chain head. Once both are on the
// same chain, update the contiguous segment between them.
void changel(int x, int y, int k)
{
    k %= p;
    for(; top[x] != top[y]; x = fa[top[x]])
    {
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        update(id[top[x]], id[x], 1, n, 1, k);
    }
    const int upper = dep[x] > dep[y] ? y : x;
    const int lower = dep[x] > dep[y] ? x : y;
    update(id[upper], id[lower], 1, n, 1, k);
}
// Return the sum (mod p) of vertex values on the tree path between x and y.
// It walks the heavy chains the same way changel does, summing one chain segment per step.
ll ql(int x, int y)
{
    ll total = 0;
    for(; top[x] != top[y]; x = fa[top[x]])
    {
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        total = (total + query(id[top[x]], id[x], 1, n, 1)) % p;
    }
    const int upper = dep[x] > dep[y] ? y : x;
    const int lower = dep[x] > dep[y] ? x : y;
    return (total + query(id[upper], id[lower], 1, n, 1)) % p;
}
// Add k (reduced mod p) to every vertex in x's subtree. dfs2 numbers a subtree
// contiguously, so the subtree is exactly positions [id[x], id[x] + size[x] - 1].
void changes(int x, int k)
{
    const int first = id[x];
    const int last = first + size[x] - 1;
    update(first, last, 1, n, 1, k % p);
}
// Return the sum (mod p) of vertex values in x's subtree, which is the contiguous
// range of positions [id[x], id[x] + size[x] - 1].
ll qs(int x)
{
    const int first = id[x];
    const int last = first + size[x] - 1;
    return query(first, last, 1, n, 1);
}
int main()
{
cin >> n >> m >> root >> p;
for(int i = 1;i <= n;i ++) cin >> w[i];
for(int i = 1;i < n;i ++)
{
int u, v;
cin >> u >> v;
add(u, v), add(v, u);
}
dfs1(root, 0);
dfs2(root, root);
build(1, n, 1);
while(m --)
{
int op, x, y, z;
cin >> op;
if(op == 1)
{
cin >> x >> y >> z;
changel(x, y, z);
}
if(op == 2)
{
cin >> x >> y;
cout << ql(x, y) << endl;
}
if(op == 3)
{
cin >> x >> z;
changes(x, z);
}
if(op == 4)
{
cin >> x;
cout << qs(x) << endl;
}
}
return 0;
}
by xiaozhuo @ 2023-09-20 14:54:55
呜呜呜，update 忘写 pushdown 了，我是蠢驴。害我又重写了一遍，然后又调了半天才发现。太久没写线段树了，受不了了，此帖结。
by liuye20100123 @ 2024-08-25 23:14:49
@xiaozhuo 太感谢啦，我也错在这儿了，调了半天，呜呜