syLph @ 2023-12-30 11:32:04
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstring>
#include<memory.h>
#include<vector>
using namespace std;
const int maxn = 2000010;
typedef long long ll;
ll nxt[maxn],head[maxn],ver[maxn],tot = 1;
ll dep[maxn],fa[maxn],siz[maxn],son[maxn];
ll top[maxn],rnk[maxn],cnt,dfn[maxn],a[maxn],w[maxn];
ll lazy[maxn],dat[maxn];
ll n,m,root,mod;
void pushdown(ll p,ll l,ll r){
if(lazy[p]){
ll mid = (l + r) >> 1;
lazy[p*2] += lazy[p]; dat[p*2] += lazy[p] * (mid - l + 1);
lazy[p*2+1] += lazy[p]; dat[p*2+1] += lazy[p] * (r - mid);
lazy[p] = 0;
lazy[p*2] = lazy[p*2] % mod;
lazy[p*2+1] = lazy[p*2+1] % mod;
}
}
void build(ll p,ll l,ll r){
if(l == r){
dat[p] = w[l];
return;
}
ll mid = (l + r) >> 1;
if(l<=mid) build(p*2,l,mid);
if(mid+1<=r) build(p*2+1,mid+1,r);
dat[p] = dat[p*2] + dat[p*2+1];
dat[p] = dat[p] % mod;
}
void update(ll p,ll l,ll r,ll L,ll R,ll k){
if(l>R||r<L) return;
if(L<=l&&r<=R){
dat[p] += (r - l + 1) * k; lazy[p] += k;
dat[p] = dat[p] % mod;
pushdown(p,l,r);
return;
}
pushdown(p,l,r);
ll mid = (l + r) >> 1;
if(l<=mid) update(p*2,l,mid,L,R,k);
if(mid+1<=r) update(p*2+1,mid+1,r,L,R,k);
dat[p] = dat[p*2] + dat[p*2+1];
dat[p] = dat[p] % mod;
}
ll getsum(ll p,ll l,ll r,ll L,ll R){
if(l>R||r<L) return 0;
if(L<=l&&r<=R) {
return dat[p];
}
pushdown(p,l,r);
ll mid = (l + r) >> 1,ret = 0;
if(l<=mid) ret += getsum(p*2,l,mid,L,R);
if(mid+1<=r) ret += getsum(p*2+1,mid+1,r,L,R);
return ret % mod;
}
void add(ll x,ll y){
ver[++tot] = y;
nxt[tot] = head[x];
head[x] = tot;
}
void dfs1(ll x){
son[x] = -1,siz[x] = 1;
for(ll i = head[x] ; i ; i = nxt[i]){
ll y = ver[i];
if(!dep[y]&&y!=fa[x]){
dep[y] = dep[x] + 1;
fa[y] = x;
dfs1(y);
siz[x] += siz[y];
if(son[x]==-1&&siz[ver[i]]>siz[son[x]]) son[x] = ver[i];
}
}
}
void dfs2(ll x,ll t){
top[x] = t; dfn[x] = ++cnt; w[cnt] = a[x];
if(son[x] == -1) return;
dfs2(son[x],t);
for(ll i = head[x] ; i ; i = nxt[i])
if(ver[i]!=son[x]&&ver[i]!=fa[x])
dfs2(ver[i],ver[i]);
}
ll lca(ll x,ll y){
while(top[x] != top[y]){
if(dep[top[x]] > dep[top[y]]) x = fa[top[x]];
else y = fa[top[y]];
}
return dep[x] > dep[y] ? y : x;
}
ll query1(ll x,ll y){
ll ans = 0,res;
while(top[x] != top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans += getsum(1,1,n,dfn[top[x]],dfn[x]);
ans = (ans + mod) % mod;
x = fa[top[x]];
}
if(dep[x] > dep[y]) swap(x,y);
ans = (ans + getsum(1,1,n,dfn[x],dfn[y]) + mod) % mod;
return (ans + mod) % mod;
}
void query2(ll x,ll y,ll k){
k = (k + mod) % mod;
while(top[x] != top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
update(1,1,n,dfn[top[x]],dfn[x],k);
x = fa[top[x]];
}
if(dep[x] > dep[y]) swap(x,y);
update(1,1,n,dfn[x],dfn[y],k);
}
ll query3(ll x){
ll ans = getsum(1,1,n,dfn[x],dfn[x]+siz[x]-1);
return (ans + mod) % mod;
}
void query4(ll x,ll k){
update(1,1,n,dfn[x],dfn[x]+siz[x]-1,k);
return;
}
int main(){
//freopen("testdata.in","r",stdin);
//freopen("testdata.out","w",stdout);
scanf("%lld%lld%lld%lld",&n,&m,&root,&mod);
for(ll i = 1 ; i <= n ; i ++) scanf("%lld",&a[i]);
for(ll i = 1 ; i <= n - 1 ; i ++){
ll u,v; scanf("%lld%lld",&u,&v);
add(u,v),add(v,u);
}
dep[root] = 1;
dfs1(root);
dfs2(root,root);
build(1,1,n);
for(ll i = 0 ; i < m ; i ++){
ll opt; scanf("%lld",&opt);
if(opt == 1){
ll x,y,z; scanf("%lld%lld%lld",&x,&y,&z);
query2(x,y,z);
}else if(opt == 2){
ll x,y; scanf("%lld%lld",&x,&y);
printf("%lld\n",query1(x,y));
}else if(opt == 3){
ll x,y; scanf("%lld%lld",&x,&y);
query4(x,y);
}else{
ll x; scanf("%lld",&x);
printf("%lld\n",query3(x));
}
}
return 0;
}
不知道为什么…… 8,9,10 T 了
by _zzzzzzy_ @ 2023-12-30 11:33:10
@syLph 别都用ll试试,空间也别开那莫大
by syLph @ 2023-12-30 12:11:39
@zzzzzzy 似乎还是不行
by _zzzzzzy_ @ 2023-12-30 12:15:08
@syLph 根据你写的线段数函数有很多错误的,但是运行出来是对的,我估计是没有学得很好,然后错误大概率在线段树上