AmiyaCast @ 2023-07-19 17:01:29
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<vector>
#include<map>
#include<queue>
#include<algorithm>
using namespace std;
// Array capacity: room for up to ~1e5 vertices (1-indexed).
const int N = 1e5 + 7;
#define ll long long
// Heavy-light decomposition data, indexed by vertex:
//   son = heavy child (-1 when the vertex has no child), dfn = position in DFS order,
//   fa = parent, dep = depth (root has depth 1), siz = subtree size,
//   top = head of the heavy chain containing the vertex,
//   ed = largest dfn inside the vertex's subtree (so the subtree is [dfn, ed]).
int son[N], dfn[N], fa[N], dep[N], siz[N], top[N], ed[N];
// Adjacency list: head[v] = first edge of v, nxt = next edge, to = edge target, cnt = edges used.
// tot = DFS-order counter; rnk[pos] = vertex placed at DFS position pos (inverse of dfn).
// NOTE(review): val is not referenced anywhere in the visible code.
int head[N << 1], nxt[N << 1], to[N << 1], cnt, val[N << 1], tot, rnk[N];
// Initial vertex values, indexed by vertex.
ll a[N];
// Append the directed edge x -> y to x's adjacency list.
// Edges are numbered from 1, so an index of 0 terminates a list.
void add(int x, int y)
{
    ++cnt;
    to[cnt] = y;
    nxt[cnt] = head[x];
    head[x] = cnt;
}
#define ll long long
struct Node{
int l, r;
ll sum;
ll add;
}t[N << 2];
void up(Node &u, Node l, Node r)
{
u.sum = l.sum + r.sum;
}
void up(int p)
{
up(t[p], t[p << 1], t[p << 1 | 1]);
}
ll mod;
void down(int p)
{
if(t[p].add)
{
ll tmp = t[p].add;
t[p << 1].sum += tmp * (t[p << 1].r - t[p << 1].l + 1);
t[p << 1 | 1].sum += tmp * (t[p << 1 | 1].r - t[p << 1 | 1].l + 1);
t[p << 1].add += tmp;
t[p << 1 | 1].add += tmp;
t[p << 1].sum %= mod;
t[p << 1 | 1].sum %= mod;
t[p].add = 0;
}
}
inline void build(int p, int l, int r)
{
t[p] = Node{l, r};
if(l == r)
{
t[p].sum = a[dfn[l]];
return;
}
const int mid = (l + r) >> 1;
build(p << 1, l, mid);
build(p << 1 | 1, mid + 1, r);
up(p);
}
void ch(int p, int x, int y, ll k)
{
int l = t[p].l, r = t[p].r;
// cout << x << " / " << y << endl;
// cout << l << " & " << r << endl;
if(x <= l && r <= y)
{
t[p].add += k;
t[p].sum += (r - l + 1) * k;
return ;
}
down(p);
const int mid = (l + r) >> 1;
if(x <= mid)
ch(p << 1, x, y, k);
if(y >= mid + 1)
ch(p << 1 | 1, x, y, k);
up(p);
}
ll ask(int p, int x, int y)
{
int l = t[p].l, r = t[p].r;
if(x <= l && r <= y)
{
return t[p].sum;
}
down(p);
ll ans = 0;
const int mid = (l + r) >> 1;
if(x <= mid)
ans += ask(p << 1, x, y);
if(y >= mid + 1)
ans += ask(p << 1 | 1, x, y);
return ans;
}
// First HLD pass: for every vertex in x's subtree fill fa (parent), dep (depth),
// siz (subtree size) and son (heavy child, -1 when there is no child).
// Expects dep[x] already set to a nonzero value and dep[] == 0 for unvisited vertices.
void dfs1(int x) {
    siz[x] = 1;
    son[x] = -1;
    for(int e = head[x]; e != 0; e = nxt[e])
    {
        const int v = to[e];
        if(dep[v]) continue; // already visited (this is the parent)
        fa[v] = x;
        dep[v] = dep[x] + 1;
        dfs1(v);
        siz[x] += siz[v];
        const bool heavier = (son[x] == -1) || (siz[v] > siz[son[x]]);
        if(heavier) son[x] = v;
    }
}
void dfs2(int x, int t) { // top dfn rnk
top[x] = t; //top
tot++;
dfn[x] = tot;//dfn
ed[x] = tot;//end
rnk[tot] = x;//rnk
if (son[x] == -1)
return;
dfs2(son[x], t); // 优先对重儿子进行 DFS,可以保证同一条重链上的点 DFS 序连续
for (int i = head[x]; i; i = nxt[i])
if (to[i] != son[x] && to[i] != fa[x])
{
dfs2(to[i], to[i]);
ed[x] = max(ed[x], ed[to[i]]);
}
ed[x] = max(ed[x], ed[son[x]]);
}
// Vertex count, operation count and root vertex, as read by main().
int n, m, r;
// NOTE(review): p is not used in the visible code; the modulus lives in `mod`.
ll p;
// Read the next integer from stdin, skipping any non-digit characters before it.
// A '-' anywhere in the skipped prefix makes the result negative.
// Returns 0 when end of input is reached before any digit, so truncated input cannot hang.
inline long long read()
{
    long long x = 0, f = 1;
    int c = getchar(); // int, not char, so EOF is representable
    while (c != EOF && (c < '0' || c > '9'))
    {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}
void work1(int x, int y, ll z)
{
int fx = top[x], fy = top[y];
while(fx != fy)
{
// cout << fy << " " << fx << endl;
if(dep[fx] > dep[fy])
{
ch(1, dfn[fx], dfn[x], z);
x = fa[fx];
}else{
ch(1, dfn[fy], dfn[y], z);
y = fa[fy];
// puts("----------------------");
}
fx = top[x], fy = top[y];
}
if(dep[x] < dep[y]) ch(1, dfn[x], dfn[y], z);
else ch(1, dfn[x], dfn[y], z);
}
// Sum of vertex values on the tree path between x and y, reduced modulo `mod`.
ll work2(int x, int y)
{
    ll total = 0;
    while(top[x] != top[y])
    {
        // Always climb from x; first make x the endpoint whose chain head is strictly deeper
        // (ties climb from the original y).
        if(!(dep[top[x]] > dep[top[y]])) swap(x, y);
        total = (total + ask(1, dfn[top[x]], dfn[x]) % mod) % mod; // positions are dfn values
        x = fa[top[x]];
    }
    const int lo = (dep[x] < dep[y]) ? dfn[x] : dfn[y];
    const int hi = (dep[x] < dep[y]) ? dfn[y] : dfn[x];
    total += ask(1, lo, hi) % mod;
    return total % mod;
}
// Add z to every vertex in x's subtree, which occupies DFS positions [dfn[x], ed[x]].
void work3(int x, ll z)
{
    const int lo = dfn[x];
    const int hi = ed[x];
    ch(1, lo, hi, z);
}
// Sum of the values in x's subtree (DFS positions [dfn[x], ed[x]]), reduced modulo `mod`.
ll work4(int x)
{
    const ll subtreeSum = ask(1, dfn[x], ed[x]);
    return subtreeSum % mod;
}
int main(){
n = read(), m = read(), r = read(), mod = read();
for(int i = 1; i <= n; ++i)
a[i] = read();
for(int i = 1; i < n; ++i)
{
int x = read(), y = read();
add(x, y);
add(y, x);
}
dep[r] = 1;
dfs1(r);
dfs2(r, r);
// puts(" --- start of check_dfn");
// for(int i = 1; i <= n; ++i)
// {
// cout << dfn[i] << " " << end[i] << endl;
// }puts(" ---end of check_dfn");
build(1, 1, n);
// for(int i = 1; i <= 10; ++i)
// {
// cout << t[i].l << " " << t[i].r << endl;
// }puts(" ---end of check_l&r");
// for(int i = 1; i <= n; ++i)
// {
// cout << fa[i] << " ";
// }puts(" ---end of check_fa");
for(int i = 1; i <= m; ++i)
{
// cout << "operation " << i << ": " << endl;
ll opt, x, y, z;
opt = read();
if(opt == 1)
{
x = read(), y = read(), z = read();
work1(x, y, z);
}
if(opt == 2)
{
x = read(), y = read();
cout << work2(x, y) << endl;
}
if(opt == 3){
x = read(), z = read();
work3(x, z);
}
if(opt == 4)
{
x = read();
cout << work4(x) << endl;
}
}
return 0;
}
by AmiyaCast @ 2023-07-19 18:05:04
过了点10
by AmiyaCast @ 2023-07-19 18:11:16
不是点10 是点11
by AmiyaCast @ 2023-07-19 19:35:11
de出来了,work1 有误 线段树有误 此帖终