80 分求助

P1600 [NOIP2016 提高组] 天天爱跑步

Others @ 2022-04-08 20:49:04

最后四个点挂了 qwq。

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N=2000005;
int n,q,root,nxt[N<<1],head[N],to[N<<1],cnt=0,fa[N],depth[N],top[N],sze[N],son[N],ans[N],tot1[N<<2],tot2[N<<2],w[N];
vector<int> vec1[N],vec2[N],vec3[N],vec4[N];
void add(int x,int y){
    nxt[++cnt]=head[x],head[x]=cnt,to[cnt]=y;
}
void dfs1(int i){
    depth[i]=depth[fa[i]]+1;
    sze[i]=1;
    for(int p=head[i];p;p=nxt[p]){
        if(to[p]==fa[i]) continue;
        fa[to[p]]=i;
        dfs1(to[p]);
        sze[i]+=sze[to[p]];
        if(son[i]==0||sze[son[i]]<sze[to[p]]) son[i]=to[p];
    }
    return ;
} 
void dfs2(int i,int Top){
    top[i]=Top;
    if(son[i]>0) dfs2(son[i],Top);
    for(int p=head[i];p;p=nxt[p]){
        if(to[p]==fa[i]||to[p]==son[i]) continue;
        dfs2(to[p],to[p]);
    } 
    return ; 
}
int getlca(int x,int y){
    while(top[x]!=top[y]){
        if(depth[x]>depth[y]) x=fa[top[x]];
        else y=fa[top[y]];
    }
    return depth[x]<depth[y]?(!x?1:x):(!y?1:y);
}
void dfs(int p,int cnt1,int cnt2) {
    for(int i=head[p];i;i=nxt[i]) 
        if(to[i]!=fa[p]) 
            dfs(to[i],tot1[w[to[i]]+depth[to[i]]+2*N],tot2[w[to[i]]-depth[to[i]]+2*N]);
    for(auto &lxl:vec1[p]) ++tot1[lxl];
    for(auto &lxl:vec2[p]) ++tot2[lxl];
    for(auto &lxl:vec3[p]) --tot1[lxl];
    for(auto &lxl:vec4[p]) --tot2[lxl];
//  cout << p << ":\n";
//  for(int i=N;i<=N+10;i++) printf("%d ",tot1[i]);
//  puts("");
//  for(int i=N;i<=N+10;i++) printf("%d ",tot2[i]);
//  puts("");
    ans[p]=tot1[w[p]+depth[p]+2*N]+tot2[w[p]-depth[p]+2*N]-cnt1-cnt2;
    return ;
}
signed main() {
//  freopen("P1600_17.in","r",stdin);
//  freopen("my","w",stdout);
    cin >> n >> q;
    root=1;
    int x,y;
    for(int i=1;i<n;++i){
        scanf("%lld%lld",&x,&y);
        add(x,y),add(y,x);
    }
    dfs1(root);
    dfs2(root,root);
    for(int i=1;i<=n;++i){
        scanf("%lld",&w[i]);
    }
    for(int i=1,dis;i<=q;i++) {
        scanf("%lld%lld",&x,&y),dis=getlca(x,y);
        vec1[x].push_back(depth[x]+2*N),vec2[y].push_back(depth[x]-2*depth[dis]+2*N),vec3[dis].push_back(depth[x]+2*N),vec4[fa[dis]].push_back(depth[x]-2*depth[dis]+2*N);
    }
    dfs(1,0,0);
    for(int i=1;i<=n;i++) printf("%lld ",ans[i]);
    return 0;
}

by Others @ 2022-04-09 07:54:47

受不了了,祖传 lca 爆了...


|