求助

P3806 【模板】点分治 1

Svemit @ 2022-12-09 13:53:11

T飞了30pts

#include<bits/stdc++.h>
using namespace std;
const int N=1e4+5;
int n,m,cnt,tot,s,root,k,ans;
int head[N],size[N],f[N],d[N],dep[N];
bool vis[N];
struct edge
{
    int nex,to,w;
}e[N<<1];

inline void add_edge(int u,int v,int w)
{
    e[++cnt].to=v;
    e[cnt].w=w;
    e[cnt].nex=head[u];
    head[u]=cnt;
}

inline void get_root(int u,int fa)
{
    size[u]=1;
    f[u]=0;
    for(int i=head[u];i;i=e[i].nex)
    {
        int v=e[i].to;
        if(v==fa||vis[v]) continue;
        get_root(v,u);
        size[u]+=size[v];
        f[u]=max(f[u],size[v]);
    }
    f[u]=max(f[u],s-f[u]);
    if(f[u]<f[root])
      root=u;
}

inline void get_dep(int u,int fa)
{
    dep[++tot]=d[u];
    for(int i=head[u];i;i=e[i].nex)
    {
        int v=e[i].to,w=e[i].w;
        if(v==fa||vis[v]||d[u]+w>k) continue;
        d[v]=d[u]+w;
        get_dep(v,u);
    }
}

inline int get_sum(int u,int dis)
{
    d[u]=dis;
    tot=0;
    int sum=0;
    get_dep(u,0);
    sort(dep+1,dep+1+tot);
    int l=1,r=tot;
    while(l<r)
    {
        if(dep[l]+dep[r]<k) l++;
        else
          if(dep[l]+dep[r]>k) r--;
          else
          {
            if(dep[l]==dep[r])
            {
                sum+=(r-l+1)*(r-l)/2;
                break;
            }
            int st=l,ed=r;
            while(dep[st]==dep[l])
              st++;
            while(dep[ed]==dep[r])
              ed--;
             sum+=(st-l)*(r-ed);
            l=st;
            r=ed; 
          }
    }
    return sum;
}

inline void solve(int u)
{
    vis[u]=true;
    ans+=get_sum(u,0);
    for(int i=head[u];i;i=e[i].nex)
    {
        int v=e[i].to,w=e[i].w;
        if(vis[v]) continue;
        ans-=get_sum(v,w);
        root=0;
        s=size[v];
        get_root(v,u);
        solve(v);
    }
}

int main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(NULL);
    std::cout.tie(NULL);
    f[0]=0x3f3f3f3f;
    cin>>n>>m;
    for(int i=1;i<n;i++)
    {
        int u,v,w;
        cin>>u>>v>>w;
        add_edge(u,v,w);
        add_edge(v,u,w);
    }
    while(m--)
    {
        cin>>k;
        memset(vis,false,sizeof(vis));
        root=0;
        s=n;
        ans=0;
        get_root(1,0);
        solve(root);
        if(ans)
          cout<<"AYE\n";
        else
          cout<<"NAY\n";
    }
    return 0;
}

|