求解KDT为什么我的不平衡重构跑不过暴力

P4148 简单题

lzyzs @ 2025-01-09 16:30:36

不平衡重构代码&测试

#include <bits/stdc++.h>
using namespace std;
const int N = 2e6 + 10, INF = 2e9;
const double bs = 0.9;
int n;
struct node {
    int nums[2], w;
};
struct trnode {
    int X[2], Y[2];
    int d, ls, rs, sum, val;
    void pr () {
        printf ("X[0]:%d Y[0]:%d X[1]:%d Y[1]:%d\n", X[0], Y[0], X[1], Y[1]);
    }
    node p;
    inline void clear (node x, int pd) {
        ls = rs = 0; d = pd;
        sum = 1, val = x.w;
        X[0] = X[1] = x.nums[0];
        Y[0] = Y[1] = x.nums[1];
        p = x;
    }
}tr[N];
inline int max (int a, int b) {return a < b ? b : a;}
inline int max (int a, int b, int c) {return max(a, max(b, c));}
inline int min (int a, int b) {return a > b ? b : a;}
inline int min (int a, int b, int c) {return min(a, min(b, c));}
inline bool cmp0 (node a, node b) {return a.nums[0] < b.nums[0];}
inline bool cmp1 (node a, node b) {return a.nums[1] < b.nums[1];}
int fw[N], cnt, id, top;
node stc[N];
int newnode () {
    if (cnt) return fw[cnt--];
    return ++id;
}
inline void update (int k) {
    tr[k].X[0] = min (tr[tr[k].ls].X[0], tr[tr[k].rs].X[0], tr[k].p.nums[0]);       
    tr[k].X[1] = max (tr[tr[k].ls].X[1], tr[tr[k].rs].X[1], tr[k].p.nums[0]);       
    tr[k].Y[0] = min (tr[tr[k].ls].Y[0], tr[tr[k].rs].Y[0], tr[k].p.nums[1]);       
    tr[k].Y[1] = max (tr[tr[k].ls].Y[1], tr[tr[k].rs].Y[1], tr[k].p.nums[1]);       
    tr[k].sum = tr[tr[k].ls].sum + tr[tr[k].rs].sum + 1;
    tr[k].val = tr[tr[k].ls].val + tr[tr[k].rs].val + tr[k].p.w;
}
void dfs (int k) {
    if (!k) return;
    stc[++top] = tr[k].p;
    fw[++cnt] = k;
    dfs (tr[k].ls); dfs (tr[k].rs);
}
int build (int l, int r, int d) {
    if (l > r) return 0; 
    int k = newnode(), mid = (l + r) >> 1;
    nth_element(stc + l, stc + mid, stc + r + 1, d ? cmp1 : cmp0);  
    tr[k].clear(stc[mid], d);
    tr[k].ls = build (l, mid - 1, d ^ 1);
    tr[k].rs = build (mid + 1, r, d ^ 1);
    update(k);
    return k;
}
inline void check (int &k, int d) {
    if (bs * tr[k].sum < max (tr[tr[k].ls].sum, tr[tr[k].rs].sum)) {
        top = 0; dfs (k);
        k = build(1, top, d);
    }
}
void insert (int &k, node p, int d) {
    if (!k) {
        k = newnode();
        tr[k].clear(p, d);
        return;
    }
    if (tr[k].p.nums[d] <= p.nums[d]) insert(tr[k].ls, p, d ^ 1);
    else insert(tr[k].rs, p, d ^ 1);
    update(k);
    check(k, d);
}
inline bool in (int a, int b, int c, int d, int A, int B, int C, int D) {
    return (A <= a && B <= b && C >= c && D >= d);
}
inline bool df (int a, int b, int c, int d, int A, int B, int C, int D) {
    return (c < A || d < B || C < a || D < b);
}
int query (int k, int a, int b, int c, int d) {
    if (!k) return 0;
    int res = 0;
    if (in(tr[k].X[0], tr[k].Y[0], tr[k].X[1], tr[k].Y[1], a, b, c, d)) return tr[k].val;
    if (df(a, b, c, d, tr[k].X[0], tr[k].Y[0], tr[k].X[1], tr[k].Y[1])) return 0;
    if (in(tr[k].p.nums[0], tr[k].p.nums[1], tr[k].p.nums[0], tr[k].p.nums[1], a, b, c, d)) res = tr[k].p.w;
    return res + query(tr[k].ls, a, b, c, d) + query(tr[k].rs, a, b, c, d);
}
signed main () {
    tr[0] = {INF, -INF, INF, -INF, 0, 0, 0, 0, 0};
    scanf("%d", &n);
    int rot = 0, lsan = 0;
    while (1) {
        int opt; scanf("%d", &opt);
        if (opt == 3) break;
        if (opt == 1) {
            int x, y, w;
            scanf("%d%d%d", &x, &y, &w);
            x ^= lsan; y ^= lsan; w ^= lsan; 
            insert(rot, {x, y, w}, 1);
        } else {
            int a, b, c, d;
            scanf("%d%d%d%d", &a, &b, &c, &d);
            a ^= lsan; b ^= lsan; c ^= lsan; d ^= lsan;
            cout << (lsan = query(rot, a, b, c, d)) << '\n';
        }
    }
    return 0;
}

不重构&测试

#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 10, INF = 2e9;
const double bs = 2;
int n;
struct node {
    int nums[2], w;
};
struct trnode {
    int X[2], Y[2];
    int d, ls, rs, sum, val;
    void pr () {
        printf ("X[0]:%d Y[0]:%d X[1]:%d Y[1]:%d\n", X[0], Y[0], X[1], Y[1]);
    }
    node p;
    inline void clear (node x, int pd) {
        ls = rs = 0; d = pd;
        sum = 1, val = x.w;
        X[0] = X[1] = x.nums[0];
        Y[0] = Y[1] = x.nums[1];
        p = x;
    }
}tr[N];
inline int max (int a, int b) {return a < b ? b : a;}
inline int max (int a, int b, int c) {return max(a, max(b, c));}
inline int min (int a, int b) {return a > b ? b : a;}
inline int min (int a, int b, int c) {return min(a, min(b, c));}
inline bool cmp0 (node a, node b) {return a.nums[0] < b.nums[0];}
inline bool cmp1 (node a, node b) {return a.nums[1] < b.nums[1];}
int fw[N], cnt, id, top;
node stc[N];
int newnode () {
    if (cnt) return fw[cnt--];
    return ++id;
}
inline void update (int k) {
    tr[k].X[0] = min (tr[tr[k].ls].X[0], tr[tr[k].rs].X[0], tr[k].p.nums[0]);       
    tr[k].X[1] = max (tr[tr[k].ls].X[1], tr[tr[k].rs].X[1], tr[k].p.nums[0]);       
    tr[k].Y[0] = min (tr[tr[k].ls].Y[0], tr[tr[k].rs].Y[0], tr[k].p.nums[1]);       
    tr[k].Y[1] = max (tr[tr[k].ls].Y[1], tr[tr[k].rs].Y[1], tr[k].p.nums[1]);       
    tr[k].sum = tr[tr[k].ls].sum + tr[tr[k].rs].sum + 1;
    tr[k].val = tr[tr[k].ls].val + tr[tr[k].rs].val + tr[k].p.w;
}
void dfs (int k) {
    if (!k) return;
    stc[++top] = tr[k].p;
    fw[++cnt] = k;
    dfs (tr[k].ls); dfs (tr[k].rs);
}
int build (int l, int r, int d) {
    if (l > r) return 0; 
    int k = newnode(), mid = (l + r) >> 1;
    nth_element(stc + l, stc + mid, stc + r + 1, d ? cmp1 : cmp0);  
    tr[k].clear(stc[mid], d);
    tr[k].ls = build (l, mid - 1, d ^ 1);
    tr[k].rs = build (mid + 1, r, d ^ 1);
    update(k);
    return k;
}
void check (int &k, int d) {
    if (bs * tr[k].sum < max (tr[tr[k].ls].sum, tr[tr[k].rs].sum)) {
        top = 0; dfs (k);
        k = build(1, top, d);
    }
}
void insert (int &k, node p, int d) {
    if (!k) {
        k = newnode();
        tr[k].clear(p, d);
        return;
    }
    if (tr[k].p.nums[d] <= p.nums[d]) insert(tr[k].ls, p, d ^ 1);
    else insert(tr[k].rs, p, d ^ 1);
    update(k);
    check(k, d);
}
inline bool in (int a, int b, int c, int d, int A, int B, int C, int D) {
    return (A <= a && B <= b && C >= c && D >= d);
}
inline bool df (int a, int b, int c, int d, int A, int B, int C, int D) {
    return (c < A || d < B || C < a || D < b);
}
int query (int k, int a, int b, int c, int d) {
    if (!k) return 0;
    int res = 0;
    if (in(tr[k].X[0], tr[k].Y[0], tr[k].X[1], tr[k].Y[1], a, b, c, d)) return tr[k].val;
    if (df(a, b, c, d, tr[k].X[0], tr[k].Y[0], tr[k].X[1], tr[k].Y[1])) return 0;
    if (in(tr[k].p.nums[0], tr[k].p.nums[1], tr[k].p.nums[0], tr[k].p.nums[1], a, b, c, d)) res = tr[k].p.w;
    return res + query(tr[k].ls, a, b, c, d) + query(tr[k].rs, a, b, c, d);
}
signed main () {
    tr[0] = {INF, -INF, INF, -INF, 0, 0, 0, 0, 0};
    scanf("%d", &n);
    int rot = 0, lsan = 0;
    while (1) {
        int opt; scanf("%d", &opt);
        if (opt == 3) break;
        if (opt == 1) {
            int x, y, w;
            scanf("%d%d%d", &x, &y, &w);
            x ^= lsan; y ^= lsan; w ^= lsan; 
            insert(rot, {x, y, w}, 1);
        } else {
            int a, b, c, d;
            scanf("%d%d%d%d", &a, &b, &c, &d);
            a ^= lsan; b ^= lsan; c ^= lsan; d ^= lsan;
            cout << (lsan = query(rot, a, b, c, d)) << '\n';
        }
    }
    return 0;
}

by lzyzs @ 2025-01-09 16:37:45

https://www.luogu.com.cn/record/197511205 给错了是这个


|