Svemit @ 2022-12-09 13:53:11
T飞了30pts
#include<bits/stdc++.h>
using namespace std;
const int N=1e4+5;
int n,m,cnt,tot,s,root,k,ans;
int head[N],size[N],f[N],d[N],dep[N];
bool vis[N];
struct edge
{
int nex,to,w;
}e[N<<1];
inline void add_edge(int u,int v,int w)
{
e[++cnt].to=v;
e[cnt].w=w;
e[cnt].nex=head[u];
head[u]=cnt;
}
inline void get_root(int u,int fa)
{
size[u]=1;
f[u]=0;
for(int i=head[u];i;i=e[i].nex)
{
int v=e[i].to;
if(v==fa||vis[v]) continue;
get_root(v,u);
size[u]+=size[v];
f[u]=max(f[u],size[v]);
}
f[u]=max(f[u],s-f[u]);
if(f[u]<f[root])
root=u;
}
inline void get_dep(int u,int fa)
{
dep[++tot]=d[u];
for(int i=head[u];i;i=e[i].nex)
{
int v=e[i].to,w=e[i].w;
if(v==fa||vis[v]||d[u]+w>k) continue;
d[v]=d[u]+w;
get_dep(v,u);
}
}
inline int get_sum(int u,int dis)
{
d[u]=dis;
tot=0;
int sum=0;
get_dep(u,0);
sort(dep+1,dep+1+tot);
int l=1,r=tot;
while(l<r)
{
if(dep[l]+dep[r]<k) l++;
else
if(dep[l]+dep[r]>k) r--;
else
{
if(dep[l]==dep[r])
{
sum+=(r-l+1)*(r-l)/2;
break;
}
int st=l,ed=r;
while(dep[st]==dep[l])
st++;
while(dep[ed]==dep[r])
ed--;
sum+=(st-l)*(r-ed);
l=st;
r=ed;
}
}
return sum;
}
inline void solve(int u)
{
vis[u]=true;
ans+=get_sum(u,0);
for(int i=head[u];i;i=e[i].nex)
{
int v=e[i].to,w=e[i].w;
if(vis[v]) continue;
ans-=get_sum(v,w);
root=0;
s=size[v];
get_root(v,u);
solve(v);
}
}
int main()
{
std::ios::sync_with_stdio(false);
std::cin.tie(NULL);
std::cout.tie(NULL);
f[0]=0x3f3f3f3f;
cin>>n>>m;
for(int i=1;i<n;i++)
{
int u,v,w;
cin>>u>>v>>w;
add_edge(u,v,w);
add_edge(v,u,w);
}
while(m--)
{
cin>>k;
memset(vis,false,sizeof(vis));
root=0;
s=n;
ans=0;
get_root(1,0);
solve(root);
if(ans)
cout<<"AYE\n";
else
cout<<"NAY\n";
}
return 0;
}