SilverLi @ 2023-05-13 18:11:16
#include<bits/stdc++.h>
using namespace std;
// Contest shortcut: from here on every `int` token expands to `long long`,
// which is why the entry point below is spelled `signed main`.
#define int long long
// Capacity shared by every global array in this file.
// NOTE(review): the segment-tree arrays are indexed by node id (children p<<1 and p<<1|1),
// so this presumably relies on 4*n staying below N -- confirm against the input limits.
const int N=1e6+5;
// n: vertex count, m: number of operations, r: root passed to dfs, mod: modulus for all sums;
// ax[v]: initial value read for vertex v.
int n,m,r,mod,ax[N];
// dfn[v]: position assigned to v by dfs2; Index: counter for those positions;
// a[pos]: ax[v] % mod of the vertex placed at position pos (the array build() reads).
int dfn[N],Index,a[N];
// fa[v]: parent of v; d[v]: depth (d[parent]+1); top[v]: chain head handed to dfs2 for v.
int fa[N],d[N],top[N];
// son[v]: child of v with the largest subtree size (0 when v has no child); si[v]: subtree size.
int son[N],si[N];
// Adjacency lists of the tree (each edge is stored in both directions by main).
vector<int> g[N];
// Segment tree over positions 1..n: t[p] = sum of node p, ad1[p] = pending "add" tag,
// ad2[p] = pending "assign" tag (0 is used as "no tag" for both).
int t[N],ad1[N],ad2[N];
void add1(int p,int l,int r) {
if(ad1[p]) {
t[p<<1]+=ad1[p]*((r-l+1)>>1);
t[p<<1|1]+=ad1[p]*((r-l+1)>>1);
ad1[p<<1]+=ad1[p];
ad1[p<<1|1]+=ad1[p];
ad1[p]=0;
}
}
// Push the pending range-assign tag of segment-tree node p down to its two children.
//   p    : node id (children are p<<1 and p<<1|1)
//   l, r : inclusive range of positions covered by p
// NOTE(review): ad2[p]==0 doubles as "no tag", so an assignment of the value 0
// cannot be propagated by this scheme.
void add2(int p,int l,int r) {
    if(!ad2[p]) return;
    // The children cover [l, mid] and [mid+1, r]; each child's new sum is the
    // assigned value times that child's own length (the halves differ in size
    // whenever r-l+1 is odd).
    const int mid=(l+r)>>1;
    const int lc=p<<1;
    const int rc=p<<1|1;
    t[lc]=ad2[p]*(mid-l+1);
    t[rc]=ad2[p]*(r-mid);
    ad2[lc]=ad2[p];
    ad2[rc]=ad2[p];
    ad2[p]=0;
    // An assignment overrides any addition that was still pending below.
    ad1[lc]=ad1[rc]=0;
}
// Propagate both lazy tags of node p (covering [l, r]) to its children.
// The assign tag goes first because pushing it clears the children's add tags;
// the node's own add tag is then applied on top of the assigned values.
void down(int p,int l,int r) {
    add2(p,l,r);
    add1(p,l,r);
}
void build(int l,int r,int p) {
if(l==r) { t[p]=a[l]; return; }
int m=l+r>>1;
build(l,m,p<<1);build(m+1,r,p<<1|1);
t[p]=(t[p<<1]+t[p<<1|1]+mod)%mod;
}
// Range add on the segment tree: add ch to every position in [S, T].
//   l, r : inclusive range covered by the current node
//   S, T : inclusive query range
//   p    : current node id
//   ch   : value to add (taken modulo `mod`)
void ADD(int l,int r,int S,int T,int p,int ch) {
    // Work with a representative in [0, mod): sums and tags then stay bounded
    // no matter how many updates pile up on the same node, instead of growing
    // until they overflow 64 bits.
    const int delta=(ch%mod+mod)%mod;
    if(l>=S&&r<=T) {
        t[p]=(t[p]+delta*(r-l+1))%mod;
        ad1[p]=(ad1[p]+delta)%mod;
        return;
    }
    const int mid=(l+r)>>1;
    // Children must be up to date before they are modified or re-summed.
    down(p,l,r);
    if(S<=mid) ADD(l,mid,S,T,p<<1,delta);
    if(T>mid) ADD(mid+1,r,S,T,p<<1|1,delta);
    t[p]=(t[p<<1]+t[p<<1|1]+mod)%mod;
}
void COVER(int l,int r,int S,int T,int p,int ch) {
if(l>=S&&r<=T) {
t[p]=ch*(r-l+1),
ad2[p]=ch,
ad1[p]=0;
return;
}
int m=l+r>>1;
down(p,l,r);
if(S<=m) COVER(l,m,S,T,p<<1,ch);
if(T>m) COVER(m+1,r,S,T,p<<1|1,ch);
t[p]=(t[p<<1]+t[p<<1|1]+mod)%mod;
return;
}
int SUM(int l,int r,int S,int T,int p) {
if(l>=S&&r<=T) return t[p];
int m=l+r>>1,sum=0;
down(p,l,r);
if(S<=m) sum=SUM(l,m,S,T,p<<1)%mod;
if(T>m) sum+=SUM(m+1,r,S,T,p<<1|1);
return (sum+mod)%mod;
}
// First traversal: for vertex v reached from parent f, record the parent, the
// depth, the subtree size, and the child with the largest subtree (son[v]).
// When several children tie, the first one met in g[v] is kept.
void dfs(int v,int f) {
    fa[v]=f;
    d[v]=d[f]+1;
    si[v]=1;
    int best=0;
    for(int to:g[v]) {
        if(to==f) continue;
        dfs(to,v);
        si[v]+=si[to];
        if(si[to]>best) {
            best=si[to];
            son[v]=to;
        }
    }
}
// Second traversal: give v the next position, store its value (mod `mod`) at
// that position, and record `deep` as the head of v's chain.
// The child in son[v] is visited first and inherits the chain head, so a chain
// occupies consecutive positions; every other child starts its own chain.
void dfs2(int v,int deep) {
    top[v]=deep;
    dfn[v]=++Index;
    a[Index]=ax[v]%mod;
    if(son[v]) {
        dfs2(son[v],deep);
        for(int to:g[v]) {
            if(to==fa[v]||to==son[v]) continue;
            dfs2(to,to);
        }
    }
}
// Add ch to every vertex on the tree path between u and v.
// While the endpoints sit on different chains, the endpoint whose chain head is
// deeper has its chain segment updated and then jumps to the parent of that head.
inline void add(int u,int v,int ch) {
    for(;top[u]!=top[v];u=fa[top[u]]) {
        if(d[top[u]]<d[top[v]]) std::swap(u,v);
        ADD(1,n,dfn[top[u]],dfn[u],1,ch);
    }
    // Same chain now: the shallower endpoint has the smaller position.
    const bool uDeeper=d[u]>d[v];
    const int upper=uDeeper?v:u;
    const int lower=uDeeper?u:v;
    ADD(1,n,dfn[upper],dfn[lower],1,ch);
}
// Add ch to every vertex in the subtree of u, i.e. to the si[u] consecutive
// positions that start at dfn[u].
inline void addsub(int u,int ch) {
    const int first=dfn[u];
    const int last=first+si[u]-1;
    ADD(1,n,first,last,1,ch);
}
// Sum of the values on the tree path between u and v, modulo `mod`.
// Walks chain by chain exactly like add(), accumulating one segment-tree
// query per chain segment.
inline int sum(int u,int v) {
    int total=0;
    for(;top[u]!=top[v];u=fa[top[u]]) {
        if(d[top[u]]<d[top[v]]) std::swap(u,v);
        total+=SUM(1,n,dfn[top[u]],dfn[u],1);
        total%=mod;
    }
    // Same chain now: the shallower endpoint has the smaller position.
    const bool uDeeper=d[u]>d[v];
    const int upper=uDeeper?v:u;
    const int lower=uDeeper?u:v;
    total+=SUM(1,n,dfn[upper],dfn[lower],1);
    return total%mod;
}
// Sum of the values in the subtree of u (positions dfn[u] .. dfn[u]+si[u]-1),
// modulo `mod`.
inline int sumsub(int u) {
    const int first=dfn[u];
    const int last=first+si[u]-1;
    return SUM(1,n,first,last,1)%mod;
}
// Entry point. Reads n, m, r, mod, the n vertex values and the n-1 edges,
// prepares the decomposition and the segment tree, then serves m operations:
//   1 x y z : add z to every vertex on the path x..y
//   2 x y   : print the path sum x..y
//   3 x z   : add z to every vertex in the subtree of x
//   other x : print the subtree sum of x
signed main() {
    // Answers are only read after the run, so untie/unsync the streams and
    // write '\n' instead of std::endl to avoid one flush per query.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin>>n>>m>>r>>mod;
    for(int i=1;i<=n;++i) std::cin>>ax[i];
    for(int i=1;i<n;++i) {
        int x,y;
        std::cin>>x>>y;
        g[x].push_back(y);
        g[y].push_back(x);
    }

    dfs(r,0);
    dfs2(r,r);
    build(1,n,1);

    while(m--) {
        int opt,x;
        std::cin>>opt>>x;
        if(opt==1) {
            int y,z;
            std::cin>>y>>z;
            add(x,y,z);
        } else if(opt==2) {
            int y;
            std::cin>>y;
            std::cout<<sum(x,y)<<'\n';
        } else if(opt==3) {
            int z;
            std::cin>>z;
            addsub(x,z);
        } else {
            std::cout<<sumsub(x)<<'\n';
        }
    }
    return 0;
}
by Killer_joke @ 2023-05-14 22:00:26
@NM_ljy add1 写的有问题：左右儿子的区间长度不一定都是 (r-l+1)/2，应分别取 mid-l+1 和 r-mid（其中 mid=(l+r)>>1）；add2 也有同样的问题。
by SilverLi @ 2023-05-14 22:03:38
@Killer_joke thx,+已A
我在改以前用树状数组写的**代码