TLE #7 #8 #9 #10求助!

P3806 【模板】点分治 1

jzwzy @ 2024-09-29 22:10:17

#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

const int N = 10010, M = N * 3, S = 10000010;

int n, m, Q;
int h[N], hs[N], e[M], w[M], ne[M], idx;
bool st[N];
int p[N], q[N];
bool f[S];

void add(int h[], int a, int b, int c)
{
    e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++ ;
}

int get_size(int u, int fa)
{
    if (st[u]) return 0;
    int res = 0;
    for (int i = h[u]; ~i; i = ne[i])
        if (e[i] != fa) res += get_size(e[i], u);
    return res;
}

int get_wc(int u, int fa, int tot, int& wc)
{
    if (st[u]) return 0;
    int sum = 1, ms = 0;
    for (int i = h[u]; ~i; i = ne[i])
    {
        int j = e[i];
        if (j == fa) continue;
        int t = get_wc(j, u, tot, wc);
        ms = max(ms, t);
        sum += t;
    }
    ms = max(ms, tot - ms);
    if (ms <= tot / 2) wc = u;
    return sum;
}

void get_dist(int u, int fa, int dist, int& qt)
{
    if (st[u] || dist > m) return;
    q[qt ++ ] = dist;
    for (int i = h[u]; ~i; i = ne[i])
    {
        int j = e[i];
        if (j == fa) continue;
        get_dist(j, u, dist + w[i], qt);
    }
}

int build(int u)
{
    if (st[u]) return 0;
    get_wc(u, -1, get_size(u, -1), u);
    st[u] = true;
    for (int i = h[u]; ~i; i = ne[i])
        if (!st[e[i]]) add(hs, u, build(e[i]), 0);
    return u;
}

bool calc(int u)
{
    if (st[u]) return false;
    st[u] = true;
    int pt = 0;
    for (int i = h[u]; ~i; i = ne[i])
    {
        int j = e[i], qt = 0;
        get_dist(j, -1, w[i], qt);
        for (int k = 0; k < qt; k ++ )
        {
            if (q[k] == m || f[m - q[k]]) return true;
            p[pt ++ ] = q[k];
        }
        for (int k = 0; k < qt; k ++ ) f[q[k]] = true;
    }
    for (int i = 0; i < pt; i ++ ) f[p[i]] = false;
    for (int i = hs[u]; ~i; i = ne[i])
        if (calc(e[i])) return true;
    return false;
}

int main()
{
    memset(h, -1, sizeof h);
    memset(hs, -1, sizeof hs);
    scanf("%d%d", &n, &Q);
    for (int i = 0; i < n - 1; i ++ )
    {
        int a, b, c;
        scanf("%d%d%d", &a, &b, &c);
        add(h, a, b, c), add(h, b, a, c);
    }
    int root = build(1);
    while (Q -- )
    {
        memset(st, 0, sizeof st);
        scanf("%d", &m);
        if (calc(root)) puts("AYE");
        else puts("NAY");
    }
    return 0;
}

|