70pts的LCA求助!

P1600 [NOIP2016 提高组] 天天爱跑步

wudiss8 @ 2020-09-26 17:25:39

按照第一篇题解的思路写的代码，挂了 #1 和 #9~#13 这几个测试点，求大佬帮忙看看错在哪。

#include<bits/stdc++.h>
using namespace std;
// Offset added to every b2 index: its natural key (w[x]-dep[x], or dis-dep[t])
// can be negative, so it is shifted by SIZE to stay inside the array.
const int SIZE=300000;

// Tree adjacency list (main() inserts every undirected edge in both directions):
// poi[x] = first edge of x, next[e] = following edge, to[e] = edge target.
// NOTE(review): with <bits/stdc++.h> + `using namespace std`, the global name
// `next` can collide with std::next in C++11 and later ("reference to 'next' is
// ambiguous"). Renaming it (e.g. `nxt`) in every user avoids the problem.
int tot;
int next[610001];
int poi[610001];
int to[610001];

// List 1: path ids hung on their end node t (filled by add1 in main).
int tot1;
int next1[610001];
int poi1[610001];
int to1[610001];

// List 2: path ids hung on their LCA (filled by add2 in main).
int tot2;
int next2[610001];
int poi2[610001];
int to2[610001];

// fa[x][k] = 2^k-th ancestor of x (0 past the root); dep[x] = depth, root = 1.
int fa[610001][21];
int dep[610001];

// b1 / b2: counting buckets for the upward / downward halves of paths.
// w[x]: time at which the observer on x looks; js[x]: number of paths starting at x;
// dis[i]: edge length of path i.
int b1[610001];
int b2[610001];
int w[610001];
int js[610001];
int dis[610001];

// s[i], t[i]: start and end node of path i.
int s[610001];
int t[610001];

// ans[x]: number of runners observed at node x.
int ans[610001];
inline void addt(int x,int y){
    tot++;
    next[tot]=poi[x];poi[x]=tot;to[tot]=y;
}
inline void add1(int x,int y){
    tot1++;
    next1[tot1]=poi1[x];poi1[x]=tot1;to1[tot1]=y;
}
inline void add2(int x,int y){
    tot2++;
    next2[tot2]=poi2[x];poi2[x]=tot2;to2[tot2]=y;
}
inline void dfs(int x,int fat){
    fa[x][0]=fat;
    dep[x]=dep[fat]+1;
    for(register int i=1;i<=20;i++)
    fa[x][i]=fa[fa[x][i-1]][i-1];
    for(register int e=poi[x];e;e=next[e]){
        if(to[e]==fat)continue;
        dfs(to[e],x);
    }
}
inline int glca(int x,int y){
    if(dep[y]>dep[x])swap(x,y);
    for(register int i=20;i>=0;i--){
        if(dep[fa[x][i]]>=dep[y]){
            x=fa[x][i];
        }
    }
    if(x==y)return x;
    for(register int i=20;i>=0;i--){
        if(fa[x][i]!=fa[y][i]){
            x=fa[x][i];
            y=fa[y][i];
        }
    }
    return fa[x][0];
}
inline void dfs2(int x,int fat){
    int t1=b1[dep[x]+w[x]],t2=b2[w[x]-dep[x]+SIZE];
    for(register int e=poi[x];e;e=next[e]){
        if(to[e]==fat)continue;
        dfs2(to[e],x);
    }
    b1[dep[x]]=b1[dep[x]]+js[x];
    for(register int e=poi1[x];e;e=next1[e]){
        b2[dis[to1[e]]-dep[t[to1[e]]]+SIZE]++;
    }
    ans[x]=ans[x]+b1[dep[x]+w[x]]-t1+b2[w[x]-dep[x]+SIZE]-t2;
    for(register int e=poi2[x];e;e=next2[e]){
        b1[s[to2[e]]]--;
        b2[dis[to2[e]]-dep[t[to2[e]]]+SIZE]--;
    }
}
// Fast reader: returns the next integer from stdin.
// Non-digit characters are skipped; a '-' seen while skipping makes the result negative.
// Returns 0 if end of input is reached before any digit, instead of looping forever.
// The getchar() result is kept in an int so EOF can be distinguished from a valid
// character, even where plain char is unsigned.
// The accumulator is named `val` so it does not shadow the global path-start array s[].
inline int read(){
    int c=getchar();
    int val=0,sign=1;
    while(c!=EOF && (c<'0' || c>'9')){
        if(c=='-')sign=-1;
        c=getchar();
    }
    while(c>='0' && c<='9'){
        val=val*10+(c-'0');
        c=getchar();
    }
    return val*sign;
}
int main(){
    int n,m;
    n=read();m=read();
    for(register int i=1;i<n;i++){
        int u,v;
        u=read();v=read();
        addt(u,v);
        addt(v,u);
    }
    for(register int i=1;i<=n;i++){
        w[i]=read();
    }
    dfs(1,0);
    for(register int i=1;i<=m;i++){
        s[i]=read();t[i]=read();
        int lca=glca(s[i],t[i]);
        dis[i]=dep[s[i]]+dep[t[i]]-2*dep[lca];
        js[s[i]]++;
        add1(t[i],i);
        add2(lca,i);
        if(dep[lca]+w[lca]==dep[s[i]])ans[lca]--;
    }
    dfs2(1,0);
    for(register int i=1;i<=n;i++)
    printf("%d ",ans[i]);
    printf("\n");
    return 0;
}

|