mc360 @ 2024-03-21 13:44:36
本人的:
#include<bits/stdc++.h>
#define N 500005
using namespace std;
struct node
{
int sum,l,r,flag;
}t[4*N];
struct edge
{
int nxt,to;
}e[2*N];
int n,a[N],m,r,root,head[N],cnt,dfn,q,x,y,k,dep[N],id[N],rnk[N],fa[N],sz[N],top[N],dson[N];
void dfs1(int u,int f)
{
dep[u]=dep[f]+1;
sz[u]=1;
fa[u]=f;
for(int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
if(v==f)
continue;
dfs1(v,u);
sz[u]+=sz[v];
sz[u]%=q;
if(sz[dson[u]]<sz[v])
dson[u]=v;
}
}
void dfs2(int u,int _top)
{
top[u]=_top;
id[u]=++dfn;
rnk[dfn]=u;
if(dson[u]==0)
return ;
dfs2(dson[u],_top);
for(int i=head[u];i;i=e[i].nxt)
{
int v=e[i].to;
if(v==fa[u]||v==dson[u])
continue;
dfs2(v,v);
}
}
void add(int x,int y)
{
e[++cnt].nxt=head[x];
head[x]=cnt;
e[cnt].to=y;
}
void push_down(int p)
{
t[p*2].flag+=t[p].flag;
t[p*2].flag%=q;
t[p*2+1].flag+=t[p].flag;
t[p*2+1].flag%=q;
t[p*2].sum+=(t[p*2].r-t[p*2].l+1)*t[p].flag;
t[p*2].sum%=q;
t[p*2+1].sum+=(t[p*2+1].r-t[p*2+1].l+1)*t[p].flag;
t[p*2+1].sum%=q;
t[p].flag=0;
}
void push_up(int p)
{
t[p].sum=t[p*2].sum+t[p*2+1].sum;
t[p].sum%=q;
}
void build(int p,int l,int r)
{
t[p].l=l,t[p].r=r;
if(l==r)
{
t[p].flag=t[p].sum=a[rnk[l]];
return ;
}
int mid=(l+r)/2;
build(p*2,l,mid);
build(p*2+1,mid+1,r);
t[p].sum=t[p*2].sum+t[p*2+1].sum;
t[p].sum%=q;
}
int query(int p,int l,int r)
{
if(t[p].r<l||t[p].l>r)
return 0;
if(t[p].l>=l&&t[p].r<=r)
{
return t[p].sum%q;
}
push_down(p);
return (query(p*2,l,r)+query(p*2+1,l,r))%q;
}
void updata(int p,int l,int r,int v)
{
if(t[p].r<l||t[p].l>r)
return ;
if(t[p].l>=l&&t[p].r<=r)
{
t[p].flag+=v;
t[p].flag%=q;
t[p].sum+=(t[p].r-t[p].l+1)*v;
t[p].sum%=q;
return ;
}
push_down(p);
updata(p*2,l,r,v);
updata(p*2+1,l,r,v);
push_up(p);
}
void change(int x,int y,int v)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])
swap(x,y);
updata(1,id[top[x]],id[x],v);
x=fa[top[x]];
}
if(dep[x]>dep[y])
swap(x,y);
updata(1,id[x],id[y],v);
}
int get(int x,int y)
{
int ral=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])
swap(x,y);
ral+=query(1,id[top[x]],id[x]);
ral%=q;
x=fa[top[x]];
}
if(dep[x]>dep[y])
swap(x,y);
ral+=query(1,id[x],id[y]);
ral%=q;
return ral;
}
int main()
{
cin >> n >> m >> r >> q;
for(int i=1;i <=n;i++)
cin >> a[i];
for(int i=1;i < n;i++)
{
int x,y;
cin >> x >> y;
add(x,y);
add(y,x);
}
dfs1(root,0);
dfs2(root,root);
build(1,1,n);
for(int i=1;i<=m;i++)
{
int op,x,y,z;
cin>>op;
if(op==1)
{
cin>>x>>y>>z;
change(x,y,z);
}
if(op==2)
{
cin>>x>>y;
cout<<get(x,y)<<endl;
}
if(op==3)
{
cin>>x>>z;
updata(1,id[x],id[x]+sz[x]-1,z);
}
if(op==4)
{
cin>>x;
cout<<query(1,id[x],id[x]+sz[x]-1)<<endl;
}
}
return 0;
}
老师的:
#include<bits/stdc++.h>
#define N 100005
using namespace std;
struct node{
int to,nxt;
}e[N*2];
int n,m,root,cnt,dfn,mod,rnk[N],w[N],fa[N],head[N],dep[N],sz[N],top[N],id[N],dson[N];
void add(int x,int y){
e[++cnt].nxt=head[x]; head[x]=cnt; e[cnt].to=y;
}
void dfs1(int u,int f){
dep[u]=dep[f]+1; sz[u]=1 ; fa[u]=f;
for(int i=head[u];i!=0;i=e[i].nxt){
int v=e[i].to;
if(v==f) continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[dson[u]]<sz[v]) dson[u]=v;
}
}
void dfs2(int u,int _top)
{
top[u]=_top; id[u]=++dfn; rnk[dfn]=u;
if(dson[u]==0) return ;
dfs2(dson[u],_top);
for(int i=head[u];i!=0;i=e[i].nxt){
int v=e[i].to;
if(v==fa[u]||v==dson[u]) continue;
dfs2(v,v);
}
}
struct node2{
int l,r,tg,sum;
}seg[N*4];
void build(int p,int l,int r){
int mid=(l+r)>>1;
seg[p].l=l; seg[p].r=r;
if(l==r) {
seg[p].tg=seg[p].sum=( w[ rnk[l] ] );
return ;
}
build(p<<1,l,mid);
build(p<<1|1,mid+1,r);
seg[p].sum=seg[p<<1].sum+seg[p<<1|1].sum;
}
void pushdown(int p){
seg[2*p].tg+=seg[p].tg;
seg[2*p].tg%=mod;
seg[2*p].sum+=(seg[2*p].r-seg[2*p].l+1)*seg[p].tg;
seg[2*p].sum%=mod;
seg[2*p+1].tg+=seg[p].tg;
seg[2*p+1].tg%=mod;
seg[2*p+1].sum+=(seg[2*p+1].r-seg[2*p+1].l+1)*seg[p].tg;
seg[2*p+1].sum%=mod;
seg[p].tg=0;
}
void update(int p,int l,int r,int val){
if(seg[p].l>=l&&seg[p].r<=r) {
seg[p].tg+=val;
seg[p].tg%=mod;
seg[p].sum+= (seg[p].r-seg[p].l+1)*val;
seg[p].sum%=mod;
return ;
}
if(seg[p].l>r||seg[p].r<l ) return ;
pushdown(p);
update(2*p,l,r,val);
update(2*p+1,l,r,val);
seg[p].sum=seg[p<<1].sum+seg[p<<1|1].sum;
seg[p].sum%=mod;
}
int quary(int p,int l,int r){
if(seg[p].l>=l&&seg[p].r<=r) {
return seg[p].sum%mod;
}
if(seg[p].l>r||seg[p].r<l ) return 0;
pushdown(p);
return (quary(2*p,l,r)+quary(2*p+1,l,r))%mod;
}
void change(int x,int y,int val){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
update(1,id[top[x]],id[x],val);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
update(1,id[x],id[y],val);
}
int get(int x,int y)
{
int rel=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
rel+=quary(1,id[top[x]],id[x]);
rel%=mod;
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
rel+=quary(1,id[x],id[y]);
rel%=mod;
return rel;
}
int main()
{
cin>>n>>m>>root>>mod;
for(int i=1;i<=n;i++) cin>>w[i];
for(int i=1;i<n;i++){
int x,y; cin>>x>>y;
add(x,y); add(y,x);
}
dfs1(root,0);
dfs2(root,root);
build(1,1,n);
for(int i=1;i<=m;i++){
int op,x,y,z;
cin>>op;
if(op==1){
cin>>x>>y>>z;
change(x,y,z);
}
if(op==2){
cin>>x>>y;
cout<<get(x,y)<<endl;
}
if(op==3){
cin>>x>>z;
update(1,id[x],id[x]+sz[x]-1,z);
}
if(op==4){
cin>>x;
cout<<quary(1,id[x],id[x]+sz[x]-1)<<endl;
}
}
return 0;
}
by Rosaya @ 2024-03-21 13:55:11
你的 root 是不是根本没值?