可爱萌新求助 80 分 KD-Tree

P4148 简单题

Mogeko @ 2022-01-10 14:38:54

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define RN 500005

typedef int I;
typedef char C;
typedef long long L;

#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define SWAP(T, a, b) { T t = a; a = b; b = t; }

// K-D Tree

typedef struct
{
    I n[2];
}
KDInfo;

typedef struct
{
    KDInfo pos;
    KDInfo lbnd;
    KDInfo rbnd;
    I      val;
    I      sum;
    I      cnt;
    I      dim;
    I      ch[2];
}
KDNode;

KDNode kdpool[RN];
I      kdcnt;

#define kdpos(x)  kdpool[x].pos
#define kdlbnd(x) kdpool[x].lbnd
#define kdrbnd(x) kdpool[x].rbnd
#define kdval(x)  kdpool[x].val
#define kdsum(x)  kdpool[x].sum
#define kdcnt(x)  kdpool[x].cnt
#define kddim(x)  kdpool[x].dim
#define kdlch(x)  kdpool[x].ch[0]
#define kdrch(x)  kdpool[x].ch[1]

KDInfo kdbuf[RN];
I      kdind[RN];

I chooseKD(I l, I r)
{
    I mx = 0, mxvar = 0;
    for (I i = 0; i < 2; i++)
    {
        L sum = 0, sum2 = 0;
        for (I j = l; j <= r; j++)
        {
            sum += 1ll * kdbuf[j].n[i];
            sum2 += 1ll * kdbuf[j].n[i] * kdbuf[j].n[i];
        }
        if ((r - l + 1) * sum2 - sum * sum > mxvar) 
            mxvar = (r - l + 1) * sum2 - sum * sum, mx = i;
    }
    return mx;
}

I sortKD(I l, I r, I k)
{
    I mx = chooseKD(l, r);
    while (1)
    {
        I i = l, j = r;
        I pivot = kdbuf[l + rand() % (r - l + 1)].n[mx];
        do
        {
            while (kdbuf[i].n[mx] < pivot) i++;
            while (pivot < kdbuf[j].n[mx]) j--;
            if (i <= j)
            {
                SWAP(KDInfo, kdbuf[i], kdbuf[j]);
                SWAP(I, kdind[i], kdind[j]);
                i++, j--;
            }
        }
        while (i <= j);
        if (k >= i) l = i;
        else if (k <= j) r = j;
        else break;
    }
    return mx;
}

static inline void upKD(I x)
{
    kdcnt(x) = kdcnt(kdlch(x)) + 1 + kdcnt(kdrch(x));
    kdsum(x) = kdsum(kdlch(x)) + kdval(x) + kdsum(kdrch(x));
}

static inline void refreshKD(I x)
{
    for (I i = 0; i < 2; i++)
    {
        I val1 = kdpos(x).n[i], val2 = kdpos(x).n[i];
        if (kdlch(x))
        {
            val1 = MIN(val1, kdlbnd(kdlch(x)).n[i]);
            val2 = MAX(val2, kdrbnd(kdlch(x)).n[i]);
        }
        if (kdrch(x))
        {
            val1 = MIN(val1, kdlbnd(kdrch(x)).n[i]);
            val2 = MAX(val2, kdrbnd(kdrch(x)).n[i]);
        }
        kdlbnd(x).n[i] = val1;
        kdrbnd(x).n[i] = val2;
    }
}

I buildKD(I l, I r)
{
    if (l > r) return 0;
    I mid = l + ((r - l) >> 1);
    I dim = sortKD(l, r, mid);
    I x = kdind[mid];

    kddim(x) = dim;
    kdpos(x) = kdbuf[mid];
    kdlch(x) = buildKD(l, mid - 1);
    kdrch(x) = buildKD(mid + 1, r);
    refreshKD(x);
    upKD(x);
    return x;
}

void piaKD(I x, I l)
{
    if (!x) return;
    piaKD(kdlch(x), l);
    kdbuf[l + kdcnt(kdlch(x))] = kdpos(x);
    kdind[l + kdcnt(kdlch(x))] = x;
    piaKD(kdrch(x), l + kdcnt(kdlch(x)) + 1);
}

I maintainKD(I x)
{
    if (kdcnt(kdlch(x)) > kdcnt(kdrch(x)) * 4 + 4
     || kdcnt(kdrch(x)) > kdcnt(kdlch(x)) * 4 + 4)
    {
        piaKD(x, 1);
        return buildKD(1, kdcnt(x));
    }
    return x;
}

static inline C insideKD(KDInfo x, KDInfo y, KDInfo l, KDInfo r)
{
    I ret = 1;
    ret &= x.n[0] >= l.n[0] && y.n[0] <= r.n[0];
    ret &= x.n[1] >= l.n[1] && y.n[1] <= r.n[1];
    return ret;
}

I addKD(I x, KDInfo pos, I val)
{
    if (!x || insideKD(pos, pos, kdpos(x), kdpos(x)))
    {
        if (!x) x = ++kdcnt, kdpos(x) = kdlbnd(x) = kdrbnd(x) = pos;
        kdval(x) += val, kdsum(x) += val, kdcnt(x) = 1;
        return x;
    }
    if (kdpos(x).n[kddim(x)] > pos.n[kddim(x)])
        kdlch(x) = addKD(kdlch(x), pos, val);
    else
        kdrch(x) = addKD(kdrch(x), pos, val);
    refreshKD(x);
    upKD(x);
    return x;
}

I queryKD(I x, KDInfo lbnd, KDInfo rbnd)
{
    if (insideKD(kdlbnd(x), kdrbnd(x), lbnd, rbnd))
        return kdsum(x);
    I l = kdlch(x), r = kdrch(x), sum = 0;
    if (insideKD(kdpos(x), kdpos(x), lbnd, rbnd))
        sum += kdval(x);
    if (l && lbnd.n[kddim(x)] <= kdpos(x).n[kddim(x)])
        sum += queryKD(l, lbnd, rbnd);
    if (r && rbnd.n[kddim(x)] >= kdpos(x).n[kddim(x)])
        sum += queryKD(r, lbnd, rbnd);
    return sum;
}

void debugKD(I x, I dep)
{
    if (!x || dep > 10) return;
    debugKD(kdlch(x), dep + 1);
    debugKD(kdrch(x), dep + 1);
    printf("%d %d %d %d %d %d %d %d %d %d %d\n", kdcnt(x), kdval(x), kdsum(x), kdlch(x), kdrch(x), kdpos(x).n[0], kdpos(x).n[1], kdlbnd(x).n[0], kdlbnd(x).n[1], kdrbnd(x).n[0], kdrbnd(x).n[1]);
}

// Main

int main(void)
{
    I n, root = 0, lastans = 0;
    scanf("%d", &n);
    while (1)
    {
        I opr;
        scanf("%d", &opr);
        if (opr == 1)
        {
            I a, b, c;
            scanf("%d%d%d", &a, &b, &c);
            a ^= lastans, b ^= lastans, c ^= lastans;
            root = addKD(root, (KDInfo){{a, b}}, c);
            root = maintainKD(root);
        }
        else if (opr == 2)
        {
            I a, b, c, d;
            scanf("%d%d%d%d", &a, &b, &c, &d);
            a ^= lastans, b ^= lastans, c ^= lastans, d ^= lastans;
            printf("%d\n", lastans = queryKD(root, (KDInfo){{a, b}}, (KDInfo){{c, d}}));
        }
        else break;
        //debugKD(root, 0);
    }
    return 0;
}

by Ckger @ 2022-01-10 16:43:34

卡常!


by 行吟啸九州 @ 2022-01-19 21:59:07

现在弄好了吗


|