救救我WA 4&&6 求助

P3806 【模板】点分治 1

Zkl21 @ 2023-06-07 22:06:36

#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 1e4 + 10, M = N << 1, S = 1e7 + 10;
int n, m;
int h[N], e[M], ne[M], w[M], idx;
bool st[N];
int q[N], p[N], k[110];
bool f[S], ans[110];
void add(int a, int b, int c)
{
    e[++idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx;
}
int get_size(int u, int fa)
{
    if (st[u])
        return 0;
    int res = 1;
    for (int i = h[u], j = e[i]; i; i = ne[i], j = e[i])
        if (j != fa)
            res += get_size(j, u);
    return res;
}
int get_wc(int u, int fa, int tot, int &wc)
{
    if (st[u])
        return 0;
    int sum = 1, ms = 0;
    for (int i = h[u], j = e[i]; i; i = ne[i], j = e[i])
        if (j != fa)
        {
            int t = get_wc(j, u, tot, wc);
            ms = max(ms, t);
            sum += t;
        }
    ms = max(ms, tot - sum);
    if (ms <= tot / 2)
        wc = u;
    return sum;
}
void get_dist(int u, int fa, int dist, int &qt)
{
    if (st[u])
        return;
    q[qt++] = dist;
    for (int i = h[u], j = e[i]; i; i = ne[i], j = e[i])
        if (j != fa)
            get_dist(j, u, dist + w[i], qt);
}
void calc(int u)
{
    if (st[u])
        return;
    get_wc(u, -1, get_size(u, -1), u); // 找重心
    st[u] = 1;
    int pt = 0;
    f[0] = 1;
    for (int i = h[u], j = e[i]; i; i = ne[i], j = e[i])
    {
        int qt = 0;
        get_dist(j, u, w[i], qt);
        for (int l = 0; l < qt; l++)
        {
            auto t = q[l];
            if (t > 1e7)
                continue;
            p[pt++] = t;
            f[t] = 1;
        }
    }
    for (int i = 1; i <= m; i++)
        if (!ans[i])
        {
            if (f[k[i]])
            {
                ans[i] = 1;
                continue;
            }
            for (int j = 0; j < pt; j++)
                if (k[i] >= p[j] && p[j] * 2 != k[i] && f[k[i] - p[j]])
                {
                    ans[i] = 1;
                    break;
                }
        }
    for (int i = 0; i < pt; i++)
        if (p[i] <= 1e7)
            f[p[i]] = 0;
    for (int i = h[u], j = e[i]; i; i = ne[i], j = e[i])
        calc(j);
}
int main()
{
#ifndef Luogu // 记得删'n'
    freopen("E:\\in and out\\in.in", "r", stdin);
    freopen("E:\\in and out\\out.out", "w", stdout);
#endif
    ios::sync_with_stdio(0);
    cin >> n >> m;
    for (int i = 1; i < n; i++)
    {
        int a, b, c;
        cin >> a >> b >> c;
        add(a, b, c), add(b, a, c);
    }
    for (int i = 1; i <= m; i++)
        cin >> k[i];
    calc(1);
    for (int i = 1; i <= m; i++)
        puts(ans[i] ? "AYE" : "NAY");
    return 0;
}

by Night_sea_64 @ 2023-06-07 22:09:17

正好,我也想求调这题,T 了。。

厚颜无耻地借一下楼(

#include<iostream>
#include<vector>
#include<cstring>
using namespace std;
int n,m,k,d[10010],sz[10010],maxpart[10010];
struct edge{int x,w;bool f;};
vector<edge>v[10010],v2[10010];
int a[10010],cur;
bool flag[10000010],ans;
void dfs1(int x,int last)
{
    sz[x]=1;
    for(int i=0;i<v[x].size();i++)if(v[x][i].f)
        if(v[x][i].x!=last)
        {
            dfs1(v[x][i].x,x);
            sz[x]+=sz[v[x][i].x];
            maxpart[x]=max(maxpart[x],sz[v[x][i].x]);
        }
}
int find(int x)
{
    for(int i=1;i<=n;i++)
        maxpart[i]=sz[i]=0;
    dfs1(x,0);
    int minn=1e9,minid=0;
    for(int i=1;i<=n;i++)
    {
        maxpart[i]=max(maxpart[i],sz[x]-sz[i]);
        if(maxpart[i]<minn)
            minn=maxpart[i],minid=i;
    }
    return minid;
}
void dfs2(int x,int last)
{
    if(d[x]>k)return;
    if(flag[k-d[x]])ans=1;
    if(ans)return;
    for(int i=0;i<v[x].size();i++)
        if(v[x][i].x!=last)
        {
            d[v[x][i].x]=d[x]+v[x][i].w;
            dfs2(v[x][i].x,x);
        }
}
void dfs3(int x,int last)
{
    if(d[x]>k)return;
    if(!flag[d[x]])
    {
        flag[d[x]]=1;
        a[++cur]=d[x];
    }
    for(int i=0;i<v[x].size();i++)
        if(v[x][i].x!=last)
            dfs3(v[x][i].x,x);
}
void solve(int x)
{
    x=find(x);
    memset(d,0,sizeof(d));
    flag[0]=1;
    a[++cur]=0;
    for(int i=0;i<v[x].size();i++)if(v[x][i].f)
    {
        d[v[x][i].x]=v[x][i].w;
        dfs2(v[x][i].x,x);
        dfs3(v[x][i].x,x);
    }
    for(int i=1;i<=cur;i++)
        flag[a[i]]=0;
    cur=0;
    if(ans)return;
    v2[x]=v[x];
    for(int i=1;i<=n;i++)
        for(int j=0;j<v[i].size();j++)
            if(i==x||v[i][j].x==x)v[i][j].f=0;
    for(int i=0;i<v2[x].size();i++)if(v2[x][i].f)
        solve(v[x][i].x);
}
int main()
{
    cin>>n>>m;
    for(int i=1;i<n;i++)
    {
        int x,y,w;
        cin>>x>>y>>w;
        v[x].push_back({y,w,1});
        v[y].push_back({x,w,1});
    }
    while(m--)
    {
        ans=0;
        for(int i=1;i<=n;i++)
            for(int j=0;j<v[i].size();j++)
                v[i][j].f=1;
        cin>>k;
        solve(1);
        cout<<(ans?"AYE":"NAY")<<endl;
    }
    return 0;
}

|