萌萌猪头求帮助

P1600 [NOIP2016 提高组] 天天爱跑步

梦语小猪头 @ 2021-08-11 08:55:24

45分哭哭,我看提交记录有人跟我一模一样的数据点WA了,求帮助哭哭

#include<map>
#include<vector>
#include<cstdio>
#include<iostream>
#define N 300017
#define fir(i) i.first
#define sec(i) i.second
#define mk(i,j) make_pair(i,j)
#define uint unsigned int
using namespace std;

struct edge{int v,next;}e[N<<1];
int n,m,tot,num,head[N],w[N],dfn[N],rev[N],l[N],r[N],dep[N],fa[N][23],ans[N];
map<int,int>F,G;vector<pair<int,int> >x[N],y[N],res[N];

void add(int u,int v)
{
    e[++tot].v = v;
    e[tot].next = head[u];
    head[u] = tot;
}

int getlca(int x,int y)
{
    if(dep[x] < dep[y])swap(x,y);
    for(int i = 21;i >= 0;i -= 1)
        if(dep[fa[x][i]] >= dep[y])x = fa[x][i];
    for(int i = 21;i >= 0;i -= 1)
        if(fa[x][i] != fa[y][i])x = fa[x][i],y = fa[y][i];
    if(x == y)return x;
    else return fa[x][0];
}

void dfs1(int u,int f)
{
    fa[u][0] = f;dep[u] = dep[f] + 1;
    dfn[u] = ++num;rev[num] = u;l[u] = r[u] = dfn[u];
    for(int i = 1;i <= 21;i += 1)
        fa[u][i] = fa[fa[u][i-1]][i-1];
    for(int i = head[u];i;i = e[i].next)
    {
        int v = e[i].v;
        if(v == f)continue;
        dfs1(v,u);r[u] = max(r[u],r[v]);
    }
}

int main()
{
    //freopen("1600.in","r",stdin);freopen("1600.out","w",stdout);
    scanf("%d%d",&n,&m);
    for(int i = 1,u,v;i < n;i += 1)
    {
        scanf("%d%d",&u,&v);
        add(u,v);add(v,u);
    }
    dfs1(1,0);
    for(int i = 1;i <= n;i += 1){
        scanf("%d",&w[i]);
        res[l[i]].push_back(mk(i,1));res[r[i]+1].push_back(mk(i,-1));
    }
    for(int i = 1,s,t,lca,flca;i <= m;i += 1)
    {
        scanf("%d%d",&s,&t);
        lca = getlca(s,t);flca = fa[lca][0];
        x[s].push_back(mk(0,1));x[lca].push_back(mk(dep[s]-dep[lca],-1));
        y[t].push_back(mk(dep[s]+dep[t]-2*dep[lca],1));y[flca].push_back(mk(dep[s]-dep[flca],-1));
    }
    for(int i = n;i >= 1;i -= 1)
    {
        int u = rev[i];
        for(uint j = 0;j < x[u].size();j += 1)
        {
            int t = fir(x[u][j]),v = sec(x[u][j]);
            F[dep[u]+t] += v;
        }
        for(uint j = 0;j < y[u].size();j += 1)
        {
            int t = fir(y[u][j]),v = sec(y[u][j]);
            G[t-dep[u]] += v;
        }
        for(uint j = 0;j < res[i].size();j += 1)
        {
            int t = fir(res[i][j]),v = sec(res[i][j]);
            ans[t] += v*(F[w[t]+dep[t]]+G[w[t]-dep[t]]);
        }
    }
    for(int i = 1;i <= n;i += 1)
        printf("%d ",ans[i]);
    return 0;
}

|