求助：点分治 TLE

P3806 【模板】点分治 1

sinsop90 @ 2022-02-12 10:54:02

#include <bits/stdc++.h>
#define maxn 500005
using namespace std;
// Problem input: n vertices, m queries.
int n, m;
// Adjacency lists: head[u] is the newest edge out of u, tot counts edges in e[].
int head[maxn], tot;
// Centroid search state: sz = subtree size, f = size of the largest piece left
// if the vertex were removed, S = size of the component being searched,
// rt = best centroid candidate found so far (0 = none yet).
int f[maxn], sz[maxn], S, rt;
// Queries: id[i] is the i-th queried distance; ans[i] becomes 1 once found.
int ans[maxn], id[maxn];
// Number of valid entries currently stored in dep[1..cnt].
int cnt;
// vis[x] becomes true once x has been processed by solve().
bool vis[maxn];

// Edge record: target vertex, previous edge in the same list, weight.
struct node {
    int v, pre, w;
} e[maxn << 1];

// Per-vertex record gathered around the current centre: distance from the
// centre and the neighbour of the centre whose subtree contains the vertex
// (0 for the centre itself).
struct nod {
    int dis, belong;
} dep[maxn];
void add(int u, int v, int w) {
    e[++tot].v = v;
    e[tot].w = w;
    e[tot].pre = head[u];
    head[u] = tot;
}
// DFS over the unvisited component containing `now` (entered from `fa`).
// Fills sz[] with subtree sizes and f[] with the largest piece that would
// remain if each vertex were removed (using S as the component size), and
// leaves in rt the first vertex, in DFS finishing order, with the smallest f.
// Callers reset rt to 0 before the call.
void getroot(int now, int fa) {
    int heaviest = 0;
    sz[now] = 1;
    for (int i = head[now]; i != 0; i = e[i].pre) {
        const int to = e[i].v;
        if (to == fa || vis[to]) continue;
        getroot(to, now);
        sz[now] += sz[to];
        if (sz[to] > heaviest) heaviest = sz[to];
    }
    // The piece "above" now: everything in the component outside its subtree.
    if (S - sz[now] > heaviest) heaviest = S - sz[now];
    f[now] = heaviest;
    if (rt == 0 || f[now] < f[rt]) rt = now;
}
// Append one dep[] record for every unvisited vertex in the subtree hanging
// below `now` (entered from `fa`). Each record stores the accumulated path
// length (`last` for `now`, plus edge weights further down) and the tag `x`.
void getdep(int now, int fa, int x, int last) {
    ++cnt;
    dep[cnt].belong = x;
    dep[cnt].dis = last;
    for (int i = head[now]; i != 0; i = e[i].pre) {
        const int to = e[i].v;
        if (to == fa || vis[to]) continue;
        getdep(to, now, x, last + e[i].w);
    }
}
// Ordering for sort(): ascending by distance from the centre.
bool cmp(nod lhs, nod rhs) {
    return rhs.dis > lhs.dis;
}
// Answer every still-open query using paths that pass through centre `now`.
// Collects (distance from `now`, neighbour-subtree tag) for every unvisited
// vertex around `now` into dep[1..cnt], adds `now` itself with tag 0, sorts by
// distance, then for each query k = id[i] runs a two-pointer scan for two
// entries whose distances sum to k and whose tags differ (the two vertices lie
// in different subtrees of `now`, or one of them is `now`). Sets ans[i] = 1 on
// success; already-answered queries are skipped.
// Cost per call: O(cnt log cnt + m * cnt).
void getsum(int now) {
    cnt = 0;
    for (int i = head[now]; i; i = e[i].pre) {
        int v = e[i].v;
        if (!vis[v]) getdep(v, now, v, e[i].w);
    }
    // The centre itself: distance 0, tag 0 (assumes vertex ids are >= 1, so 0
    // never collides with a subtree tag).
    dep[++cnt] = nod{0, 0};
    sort(dep + 1, dep + 1 + cnt, cmp);
    for (int i = 1; i <= m; i++) {
        if (ans[i]) continue;
        const int k = id[i];
        // Only dep[1..cnt] belongs to this centre; entries past cnt are stale
        // leftovers from earlier centres and must never be read.
        int l = 1, r = cnt;
        while (l < r) {
            const int sum = dep[l].dis + dep[r].dis;
            if (sum > k) {
                r--;
            } else if (sum < k) {
                l++;
            } else if (dep[l].belong != dep[r].belong) {
                ans[i] = 1;
                break;
            } else if (dep[r].dis == dep[r - 1].dis) {
                // Same subtree, but another entry with the same distance may
                // belong to a different subtree: try it first.
                r--;
            } else {
                // dep[r] is the only entry at distance k - dep[l].dis and it is
                // in l's subtree, so dep[l] cannot be part of any valid pair.
                l++;
            }
        }
    }
}
// Number of unvisited vertices reachable from `now` without stepping back
// into `fa` or through any visited vertex.
static int getsize(int now, int fa) {
    int total = 1;
    for (int i = head[now]; i; i = e[i].pre) {
        int v = e[i].v;
        if (v != fa && !vis[v]) total += getsize(v, now);
    }
    return total;
}

// Centroid-decomposition driver: mark centre x as used, answer queries for
// paths through it, then recurse into each remaining piece via that piece's
// own centroid.
void solve(int x) {
    vis[x] = true;
    getsum(x);
    for (int i = head[x]; i; i = e[i].pre) {
        int v = e[i].v;
        if (vis[v]) continue;
        // Exact size of v's piece. The sz[v] left by the previous getroot is
        // wrong when v was x's parent in that search.
        S = getsize(v, 0);
        rt = 0;
        getroot(v, 0);
        // Recurse on the piece's centroid (not on v itself) so the recursion
        // depth stays logarithmic; recursing on v degrades to O(n^2) on chains.
        solve(rt);
    }
}
int main() {
    scanf("%d%d", &n, &m);
    for(int i = 1;i <= n - 1;i++) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        add(u, v, w);
        add(v, u, w);
    }
    for(int i = 1;i <= m;i++) scanf("%d", &id[i]);
    S = n, rt = 0;
    getroot(1, 0);
    solve(rt);
    for(int i = 1;i <= m;i++) {
        if(!ans[i]) printf("NAY\n");
        else printf("AYE\n");
    }
}

|