有大佬能帮忙看看我写的 kd-tree 哪里出错了吗？（哭）

P1429 平面最近点对(加强版)

Macaron_lin @ 2019-08-13 01:44:35

#include <cstdio>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
using namespace std;
// Midpoint of the current [L, R] range. Expands in place, so it relies on the
// enclosing scope having variables named L and R (as Build does).
#define M ((L + R) >> 1)
// Capacity of the point / node arrays (sized for 200000 points plus a little slack).
const int maxn = 200010;
// Sentinel "infinite" distance.
const double INF = 1.0e18;
// Axis (0 or 1) that Point::operator< compares on; Build sets it before nth_element.
int WD = 0;
// A 2-D input point. `id` is its 1-based input index; Query uses it to avoid
// matching a point with itself.
struct Point {
    int id;
    double x[2];  // the two coordinates, x[0] and x[1]
    // Orders points by their coordinate on the global axis WD (used by nth_element in Build).
    bool operator<(const Point& other) const {
        const double mine = x[WD];
        const double theirs = other.x[WD];
        return mine < theirs;
    }
} p[maxn];
// Returns x squared.
double sqr(double x) {
    const double squared = x * x;
    return squared;
}
// Squared Euclidean distance between two points (no square root taken).
double dis2(Point a, Point b) {
    const double dx = a.x[0] - b.x[0];
    const double dy = a.x[1] - b.x[1];
    return sqr(dx) + sqr(dy);
}
// kd-tree storage, indexed by node id (1..cnt); a child id of 0 means "no child".
int ls[maxn];          // left child of each node
int rs[maxn];          // right child of each node
double maxv[maxn][2];  // bounding box of each node's subtree: max coordinate per axis
double minv[maxn][2];  // bounding box of each node's subtree: min coordinate per axis
Point tp[maxn];        // the point stored at each node
int cnt;               // number of nodes allocated so far
void Merge(int id) {
    for (int i = 0; i < 2; i++) {
        maxv[id][i] = minv[id][i] = tp[id].x[i];
        if (ls[i]) {
            maxv[id][i] = max(maxv[id][i], maxv[ls[id]][i]);
            minv[id][i] = min(minv[id][i], minv[ls[id]][i]);
        }
        if (rs[i]) {
            maxv[id][i] = max(maxv[id][i], maxv[rs[id]][i]);
            minv[id][i] = min(minv[id][i], minv[rs[id]][i]);
        }
    }
}
// Builds a kd-tree over p[L..R] (inclusive) and returns its root node id,
// or 0 for an empty range. The root splits on axis `wd`; axes alternate per level.
int Build(int L, int R, int wd) {
    if (R < L) return 0;
    const int node = ++cnt;
    const int mid = M;
    WD = wd;  // axis that Point::operator< compares on
    // Put the median (by axis wd) at p[mid]: elements before it are not greater,
    // elements after it are not smaller.
    nth_element(p + L, p + mid, p + R + 1);
    tp[node] = p[mid];
    const int nextAxis = (wd + 1) % 2;
    ls[node] = Build(L, mid - 1, nextAxis);
    rs[node] = Build(mid + 1, R, nextAxis);
    Merge(node);
    return node;
}
// Euclidean distance from `a` to the bounding box of node `id` (0 when `a` lies
// inside the box). This is a lower bound on the distance from `a` to any point in
// that subtree; Query uses it for pruning.
double GetMinDis(Point a, int id) {
    double total = 0;
    for (int axis = 0; axis < 2; ++axis) {
        const double v = a.x[axis];
        const double hi = maxv[id][axis];
        const double lo = minv[id][axis];
        if (!(v > hi || v < lo)) continue;  // within the box's range on this axis
        total += min(sqr(v - hi), sqr(v - lo));
    }
    return sqrt(total);
}
double ans = INF;
void Query(Point a, int id) {
    if (!id)    return;
    double dis = sqrt(dis2(a, tp[id]));
    if (a.id != tp[id].id) {
        ans = min(ans, dis);
    }
    double disl = INF, disr = INF;
    if (ls[id]) disl = GetMinDis(a, ls[id]);
    if (rs[id]) disl = GetMinDis(a, rs[id]);
    if (disl < disr) {
        if (disl < ans) Query(a, ls[id]);
        if (disr < ans) Query(a, rs[id]);
    }
    else {
        if (disr < ans) Query(a, rs[id]);
        if (disl < ans) Query(a, ls[id]);
    }
}
int main() {
    int n;  scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        p[i].id = i;
        for (int j = 0; j < 2; j++) {
            scanf("%lf", &p[i].x[j]);
        }
    }
    int rt = Build(1, n, 0);
    for (int i = 1; i <= n; i++) {
        Query(p[i], rt);
    }
    printf("%.4f\n", ans);
    return 0;
}

|