a_sad_soul @ 2023-03-13 17:41:53
照着模板稍微修改了一下还是不行
#include<bits/stdc++.h>
#define MAXN 1000005
using namespace std;
struct node{
int to,val;
node *next;
}*head[MAXN],*tmp;
void add(int u,int v,int val)
{
tmp=new node;
tmp->val=val;
tmp->to=v;
tmp->next=head[u];
head[u]=tmp;
}
int n,k;
bool visit[MAXN],ans[100000005];
int S;
int root;
int size[MAXN],mx[MAXN];
void getroot(int u,int fa)
{
size[u]=1;
mx[u]=0;
//int maxx=-1;
for(node *i=head[u];i!=NULL;i=i->next)
{
int v=i->to;
if(v==fa||visit[v])continue;
getroot(v,u);
size[u]+=size[v];
//maxx=max(size[v],maxx);
mx[u]=max(mx[u],size[v]);
}
//maxx=max(S-size[u],maxx);
mx[u]=max(mx[u],S-size[u]);
if(mx[u]<mx[root])
root=u;
// return ;
}
int dis[MAXN],wei[MAXN],tot;
void getdis(int u,int fa,int len)
{
dis[++tot]=wei[u];
for(node *i=head[u];i!=NULL;i=i->next)
{
int v=i->to,val=i->val;
if(v==fa||visit[v])continue;
wei[v]=len+val;
getdis(v,u,len+val);
}
}
void solve(int u,int len,int k)
{
tot=0;
wei[u]=len;
getdis(u,0,len);
for(int i=1;i<=tot;++i)
for(int j=1;j<=tot;++j)
if(i!=j)
ans[dis[i]+dis[j]]+=k;
}
void divide(int u)
{
solve(u,0,1);
visit[u]=1;
//S=size[u];
for(node *i=head[u];i!=NULL;i=i->next)
{
int v=i->to;
if(visit[v])continue;
solve(v,i->val,-1);
S=size[u];
root=0,mx[0]=n;
getroot(v,u);
divide(root);
}
}
int m;
int main()
{
cin>>n>>m;
for(int i=1;i<n;++i)
{
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
add(u,v,w);
add(v,u,w);
}
S=n,mx[0]=n,root=0;
getroot(1,0);
divide(root);
for(int i=1;i<=m;++i)
{
int k;
cin>>k;
if(ans[k])printf("AYE\n");
else printf("NAY\n");
}
return 0;
}
by 5k_sync_closer @ 2023-03-13 18:47:43
@a_sad_soul