Nuyoah_awa @ 2023-06-07 13:38:56
RT,Lca+开桶,不知道哪里挂了,10pts,有没有大佬帮忙调一下
code:
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#define int long long
using namespace std;
const int N = 3e5 + 5, K = 20, D = 3e5;
struct node{
int s, t, Lca;
}p[N];
int n, m, deep[N], d[N], s[N], a[N], f[N][K + 5], w[N], box[N + N], ans[N];
vector <int> e[N];
vector <int> cnt1[N], cnt2[N], cnt3[N];
void dfs(int x, int dep, int fa)
{
deep[x] = dep;
f[x][1] = fa;
for(int i = 0;i < e[x].size();i++)
{
if(e[x][i] == fa)
continue;
dfs(e[x][i], dep + 1, x);
}
return ;
}
int lca(int x, int y)
{
if(deep[x] < deep[y])
swap(x, y);
for(int i = K;i >= 1;i--)
{
if(deep[f[x][i]] >= deep[y])
{
x = f[x][i];
}
}
if(x == y)
return x;
for(int i = K;i >= 1;i--)
{
if(f[x][i] != f[y][i])
x = f[x][i], y = f[y][i];
}
return f[x][1];
}
void dfs1(int x, int fa)
{
int tmp1 = box[deep[x]+w[x]+D];
for(int i = 0;i < cnt1[x].size();i++)
box[cnt1[x][i]+D]++;
for(int i = 0;i < e[x].size();i++)
{
if(e[x][i] == fa)
continue;
dfs1(e[x][i], x);
}
int tmp2 = box[deep[x]+w[x]+D];
ans[x] = tmp2 - tmp1;
for(int i = 0;i < cnt3[x].size();i++)
box[cnt3[x][i]+D]--;
return ;
}
void dfs2(int x, int fa)
{
int tmp1 = box[w[x]-deep[x]+D];
for(int i = 0;i < cnt2[x].size();i++)
box[cnt2[x][i]+D]++;
for(int i = 0;i < e[x].size();i++)
{
if(e[x][i] == fa)
continue;
dfs2(e[x][i], x);
}
for(int i = 0;i < cnt3[x].size();i++)
box[cnt3[x][i]+D]--;
int tmp2 = box[w[x]-deep[x]+D];
ans[x] += tmp2 - tmp1;
return ;
}
signed main()
{
scanf("%lld %lld", &n, &m);
for(int i = 1, u, v;i < n;i++)
{
scanf("%lld %lld", &u, &v);
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1, 1, 0);
for(int i = 2;i <= K;i++)
for(int j = 1;j <= n;j++)
f[j][i] = f[f[j][i-1]][i-1];
for(int i = 1;i <= n;i++)
scanf("%lld", &w[i]);
for(int i = 1;i <= m;i++)
{
scanf("%lld %lld", &p[i].s, &p[i].t);
p[i].Lca = lca(p[i].s, p[i].t);
cnt1[p[i].s].push_back(deep[p[i].s] + w[p[i].s]);
cnt2[p[i].t].push_back(deep[p[i].t] + w[p[i].t]);
cnt3[p[i].Lca].push_back(deep[p[i].t] + w[p[i].t]);
}
dfs1(1, 0);
for(int i = 0;i <= N + N;i++)
box[i] = 0;
dfs2(1, 0);
for(int i = 1;i <= n;i++)
printf("%lld ", ans[i]);
return 0;
}