点分治 WA on 1, 3 求助

P3806 【模板】点分治 1

Illus1onary_Real1ty @ 2024-06-25 09:43:24

#include <bits/stdc++.h>
#define int long long
#define mp make_pair
#define fi first
#define se second
using namespace std;

const int N = 1e4 + 10;
int n, m, tot = 0, rt = 0, vis[N], sz[N], mx[N];

struct Edge{
    int to, w;

    Edge() {}

    Edge(int to_, int w_){
        to = to_;
        w = w_; 
    }
};

vector<Edge> g[N];

int q[110], ans[110];

pair<int, int> dis[N];

void Get_Root(int u, int fno, int SZ){
    sz[u] = 1;
    mx[u] = 0;

    for (auto e : g[u]){
        int to = e.to;
        int w = e.w;

        if (to == fno || vis[to])
            continue;

        Get_Root(to, u, SZ);

        mx[u] = max(mx[u], sz[to]);
        sz[u] += sz[to];
    }

    mx[u] = max(mx[u], SZ - sz[u]);

    if (!rt && mx[u] <= SZ / 2)
        rt = u;
}

void Get_Dis(int u, int fno, int d, int from){
    dis[++tot] = mp(d, from);

    for (auto e : g[u]){
        int to = e.to;
        int w = e.w;

        if (to == fno || vis[to])
            continue;

        Get_Dis(to, u, d + w, from);
    }
}

void Calc(int u){
    tot = 0;
    dis[++tot] = mp(0, u);

    for (auto e : g[u]){
        int to = e.to;
        int w = e.w;

        if (vis[to])
            continue;

        Get_Dis(to, u, w, to);
    }

    sort(dis + 1, dis + tot + 1);

    for (int i = 1; i <= m; i++){
        int l = 1, r = tot;

        if (ans[i])
            continue;

        while (l < r){
            int dl = dis[l].fi;
            int dr = dis[r].fi;

            int fl = dis[l].se;
            int fr = dis[r].se;

            if (dl + dr > q[i])
                r--;
            else if (dl + dr < q[i])
                l++;
            else if (fl == fr){
                if (dis[r-1].fi == dis[r].fi)
                    r--;
                else    l++;
            }else{
                ans[i] = 1;
                break;
            }
        }
    }
}

void Solve(int u){
    vis[u] = 1;
    Calc(u);

    for (auto e : g[u]){
        int to = e.to;
        if (vis[to])
            continue;

        rt = 0;
        Get_Root(to, 0, sz[to]);

        Solve(rt);
    }
}

signed main(){
    cin >> n >> m;
    for (int i = 1, u, v, w; i < n; i++){
        cin >> u >> v >> w;

        g[u].push_back(Edge(v, w));
        g[v].push_back(Edge(u, w));
    }

    for (int i = 1; i <= m; i++){
        cin >> q[i];

        if (!q[i])
            ans[i] = 1;
    }

    rt = 0;
    Get_Root(1, 0, n);
    Solve(rt);

    for (int i = 1; i <= m; i++){
        if (ans[i])
            cout << "AYE" << endl;
        else    cout << "NAY" << endl;
    }

    return 0;
} 

|