splay90求助

P6136 【模板】普通平衡树(数据加强版)

kemkra @ 2021-03-16 23:15:56

提交记录

#include <cstdio>
#include <cstring>

using namespace std;

const int N = 2e6;
const int INF = 0x7fffffff;

int n, m, last, ans, tot, root;
int val[N], cnt[N], size[N], fa[N], ch[N][2];

int newnode(int x) {
    val[++tot] = x;
    cnt[tot] = size[tot] = 1;
    return tot;
}

void update(int x) {
    size[x] = size[ch[x][0]] + size[ch[x][1]] + cnt[x];
}

bool id(int x) {
    return x == ch[fa[x]][1];
}

void connect(int x, int y, bool k) {
    ch[x][k] = y;
    fa[y] = x;
}

void rotate(int x) {
    int y = fa[x], z = fa[y];
    bool k = id(x);
    connect(z, x, id(y));
    connect(y, ch[x][k ^ 1], k);
    connect(x, y, k ^ 1);
    update(y);
    update(x);
}

void splay(int x, int to) {
    while (fa[x] != to) {
        int y = fa[x];
        if (fa[y] != to) id(x) ^ id(y) ? rotate(x) : rotate(y);
        rotate(x);
    }
    if (to == 0) root = x;
}

void find(int x) {
    int u = root;
    while (val[u] != x && ch[u][x > val[u]]) u = ch[u][x > val[u]];
    splay(u, 0);
}

void insert(int x) {
    if (!root) {
        root = newnode(x);
        return;
    }
    int u = root;
    while (val[u] != x && ch[u][x > val[u]]) u = ch[u][x > val[u]];
    if (val[u] == x) {
        cnt[u]++;
        splay(u, 0);
    } else {
        connect(u, newnode(x), x > val[u]);
        splay(ch[u][x > val[u]], 0);
    }
}

void erase(int x) {
    if (!root) return;
    find(x);
    if (!root || val[root] != x) return;
    if (cnt[root] > 1) {
        cnt[root]--;
        size[root]--;
        return;
    }
    int u = ch[root][0];
    if (!u) {
        root = ch[root][1];
        fa[root] = 0;
        return;
    }
    while (ch[u][1]) u = ch[u][1];
    splay(u, root);
    connect(u, ch[root][1], 1);
    root = u;
    fa[root] = 0;
    update(root);
}

int rank(int x) {
    find(x);
    int ls = size[ch[root][0]];
    return (val[root] < x ? ls + cnt[root] : ls) + 1;
}

int kth(int x) {
    int u = root;
    while (1) {
        if (x <= size[ch[u][0]]) u = ch[u][0];
        else if (x > size[ch[u][0]] + cnt[u]) {
            x -= size[ch[u][0]] + cnt[u];
            u = ch[u][1];
        } else break;
    }
    splay(u, 0);
    return val[u];
}

int pre(int x) {
    find(x);
    if (val[root] < x) return val[root];
    int u = ch[root][0];
    if (!u) return -INF;
    while (ch[u][1]) u = ch[u][1];
    splay(u, 0);
    return val[u];
}

int nxt(int x) {
    find(x);
    if (val[root] > x) return val[root];
    int u = ch[root][1];
    if (!u) return INF;
    while (ch[u][0]) u = ch[u][0];
    splay(u, 0);
    return val[u];
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1, a; i <= n; i++) {
        scanf("%d", &a);
        insert(a);
    }
    for (int i = 1, opt, x; i <= m; i++) {
        scanf("%d%d", &opt, &x);
        x ^= last;
        if (opt == 1) insert(x);
        if (opt == 2) erase(x);
        if (opt == 3) {
            last = rank(x);
            ans ^= last;
        }
        if (opt == 4) {
            last = kth(x);
            ans ^= last;
        }
        if (opt == 5) {
            last = pre(x);
            ans ^= last;
        }
        if (opt == 6) {
            last = nxt(x);
            ans ^= last;
        }
    }
    printf("%d", ans);
    return 0;
}

by zimujun @ 2021-03-17 07:33:07

@PrHacker235 输入数据太大了,您用快读卡一下常


by 尤斯蒂亚 @ 2021-03-17 10:25:50

@PrHacker235 随机转节点,这是故意卡的数据,会让树退化,就不平衡了。我之前也是这个问题


by 尤斯蒂亚 @ 2021-03-17 10:26:55

Splay(rand()%tot+1);


by Twig @ 2021-04-13 07:41:36

@PrHacker235

void rotate(int x) {
    int y = fa[x], z = fa[y];
    bool k = id(x);
    connect(z, x, id(y));
    connect(y, ch[x][k ^ 1], k);
    connect(x, y, k ^ 1);
    update(y);
    update(x);
}

改成

void rotate(int x) {
    int y = fa[x], z = fa[y];
    bool k = id(x), kk = id(y);
    connect(x, y, k ^ 1);
    connect(y, ch[x][k ^ 1], k);
     fa[x] = z;
    if(z){connect(z, x, kk;}
    update(y);
    update(x);
}

就能A这道题了


by Twig @ 2021-04-13 07:57:17

@PrHacker235

抱歉,刚才改错了/jk,应该改成下面这样

void rotate(int x) {
    int y = fa[x], z = fa[y];
    bool k = id(x), kk = id(y);
    connect(y, ch[x][k ^ 1], k);
    connect(x, y, k ^ 1);
    fa[x] = z;
    if(z){connect(z, x, kk);}
    update(y);
    update(x);
}

by kemkra @ 2021-04-14 18:56:18

@Jair314 感谢大佬点拨!


by kemkra @ 2021-08-28 20:02:38

挖坟,警示后人

一定要注意 rotate顺序

如果 node 本身就是 root,一定不要对 0 号节点进行 connect 操作


|