wudiss8 @ 2020-09-26 17:25:39
按照第一篇题解的思路写的代码，挂了 #1 和 #9~#13 这几个测试点，求大佬帮忙看看错在哪里。
#include<bits/stdc++.h>
using namespace std;
// Offset added to (w[x]-dep[x]) and (dis[i]-dep[t[i]]) so that both can index b2
// without going negative.
const int SIZE=300000;
// Capacity shared by every per-node / per-edge / per-path array below.
const int MAXV=610001;

// Tree adjacency list (addt is called once per direction): head poi[x], link next[e],
// endpoint to[e], edge counter tot.
// NOTE(review): with `using namespace std`, an unqualified `next` is ambiguous with
// std::next in C++11 and later; refer to this array as ::next.
int tot;
int next[MAXV];
int poi[MAXV];
int to[MAXV];

// Per-node list of path ids, keyed by the path's end point t (filled via add1).
int tot1;
int next1[MAXV];
int poi1[MAXV];
int to1[MAXV];

// Per-node list of path ids, keyed by the path's LCA (filled via add2).
int tot2;
int next2[MAXV];
int poi2[MAXV];
int to2[MAXV];

// Binary lifting: fa[x][i] is the 2^i-th ancestor of x (0 past the root); dep[root]==1.
int fa[MAXV][21];
int dep[MAXV];

// b1: bucket keyed by depth; b2: bucket keyed by (dis-dep[t]+SIZE).
// w: per-node value read from input (presumably the observation time; confirm).
// js[x]: number of paths starting at x; dis[i]: edge length of path i.
int b1[MAXV];
int b2[MAXV];
int w[MAXV];
int js[MAXV];
int dis[MAXV];

// Path end points: path i runs from s[i] to t[i].
int s[MAXV];
int t[MAXV];

// Final per-node answer printed by main.
int ans[MAXV];
inline void addt(int x,int y){
tot++;
next[tot]=poi[x];poi[x]=tot;to[tot]=y;
}
// Push item y onto the front of node x's list in (poi1, next1, to1).
// main uses this to group path ids by their end point t.
inline void add1(int x,int y){
    const int id=++tot1;
    to1[id]=y;
    next1[id]=poi1[x];
    poi1[x]=id;
}
// Push item y onto the front of node x's list in (poi2, next2, to2).
// main uses this to group path ids by their LCA.
inline void add2(int x,int y){
    const int id=++tot2;
    to2[id]=y;
    next2[id]=poi2[x];
    poi2[x]=id;
}
inline void dfs(int x,int fat){
fa[x][0]=fat;
dep[x]=dep[fat]+1;
for(register int i=1;i<=20;i++)
fa[x][i]=fa[fa[x][i-1]][i-1];
for(register int e=poi[x];e;e=next[e]){
if(to[e]==fat)continue;
dfs(to[e],x);
}
}
// Lowest common ancestor of x and y via the fa[][] lifting table (levels 0..20).
// It first lifts the deeper node to the other's depth, then lifts both while their
// ancestors differ. Node 0 acts as a sentinel above the root (depth 0).
inline int glca(int x,int y){
    if(dep[x]<dep[y])
        std::swap(x,y);
    for(int k=20;k>=0;--k){
        const int up=fa[x][k];
        if(dep[y]<=dep[up])
            x=up;
    }
    if(x==y)
        return x;
    for(int k=20;k>=0;--k){
        const int ax=fa[x][k];
        const int ay=fa[y][k];
        if(ax!=ay){
            x=ax;
            y=ay;
        }
    }
    return fa[x][0];
}
inline void dfs2(int x,int fat){
int t1=b1[dep[x]+w[x]],t2=b2[w[x]-dep[x]+SIZE];
for(register int e=poi[x];e;e=next[e]){
if(to[e]==fat)continue;
dfs2(to[e],x);
}
b1[dep[x]]=b1[dep[x]]+js[x];
for(register int e=poi1[x];e;e=next1[e]){
b2[dis[to1[e]]-dep[t[to1[e]]]+SIZE]++;
}
ans[x]=ans[x]+b1[dep[x]+w[x]]-t1+b2[w[x]-dep[x]+SIZE]-t2;
for(register int e=poi2[x];e;e=next2[e]){
b1[s[to2[e]]]--;
b2[dis[to2[e]]-dep[t[to2[e]]]+SIZE]--;
}
}
// Read the next integer from stdin.
// Every non-digit character before the first digit is skipped. Any '-' seen while
// skipping makes the result negative. Reading stops at (and consumes) the first
// non-digit after the number.
inline int read(){
    int sign=1;
    char ch=getchar();
    for(;!(ch>='0'&&ch<='9');ch=getchar()){
        if(ch=='-')
            sign=-1;
    }
    int value=0;
    for(;ch>='0'&&ch<='9';ch=getchar())
        value=value*10+(ch-'0');
    return value*sign;
}
int main(){
int n,m;
n=read();m=read();
for(register int i=1;i<n;i++){
int u,v;
u=read();v=read();
addt(u,v);
addt(v,u);
}
for(register int i=1;i<=n;i++){
w[i]=read();
}
dfs(1,0);
for(register int i=1;i<=m;i++){
s[i]=read();t[i]=read();
int lca=glca(s[i],t[i]);
dis[i]=dep[s[i]]+dep[t[i]]-2*dep[lca];
js[s[i]]++;
add1(t[i],i);
add2(lca,i);
if(dep[lca]+w[lca]==dep[s[i]])ans[lca]--;
}
dfs2(1,0);
for(register int i=1;i<=n;i++)
printf("%d ",ans[i]);
printf("\n");
return 0;
}