Splay 求助!

P6136 【模板】普通平衡树(数据加强版)

Daniel2020 @ 2022-02-21 13:13:08

RT.
#include<bits/stdc++.h>
using namespace std;
// Node capacity. Node indices come from ++tot in ins() and are never reused
// (clear() zeroes a node's fields but tot is never decremented), so N must
// exceed the total number of nodes ins() allocates over the whole run.
// NOTE(review): with 4-byte int the six N-sized int arrays below take about
// 240 MB at this size -- confirm against the judge's memory limit; a bound
// derived from the problem's limits on n and m is probably sufficient.
const int N = 1e7+2;
// n: current number of stored values (read initially, then kept by main on ops 1/2)
// m: number of operations; x/opt: current operand and opcode
// tot: last allocated node index; root: splay tree root (0 = empty tree)
// lst: answer of the most recent query (ops 3-6), XORed into the next operand
// ans: XOR of all query answers, printed at the end
int n,m,x,opt,tot,lst,ans,root;
// Per-node fields indexed by node id; index 0 acts as the null node.
// f = parent, val = key, cnt = multiplicity of that key,
// siz = total multiplicity in the subtree, son[][0]/son[][1] = left/right child.
int f[N],val[N],cnt[N],siz[N],son[N][2];
// Read one signed decimal integer from stdin.
// Any non-digit characters before the first digit are skipped; a '-' among
// them makes the result negative. Returns 0 if EOF is hit before any digit
// (previously this spun forever, because a `char` cannot reliably hold EOF).
inline int read()
{
    int x = 0, sign = 1;
    int c = getchar();  // int, not char: getchar() returns EOF out of band
    while(c != EOF && (c < '0' || c > '9'))
    {
        if(c == '-') sign = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9')
    {
        x = x*10 + (c - '0');
        c = getchar();
    }
    return x*sign;
}
inline void upd(int x) { siz[x] = siz[son[x][0]]+siz[son[x][1]]+cnt[x]; }
// Which side of its parent x hangs on: true iff x is son[f[x]][1] (the right child).
inline bool fnd(int x)
{
    const int parent = f[x];
    return son[parent][1] == x;
}
// Zero every field of node x: both children, parent, key, count and subtree size.
inline void clear(int x)
{
    son[x][0] = son[x][1] = 0;
    f[x] = val[x] = cnt[x] = siz[x] = 0;
}
inline void rotate(int x)
{
    int y = f[x],z = f[y],a = fnd(x),b = a^1;
    son[y][a] = son[x][b];
    if(son[x][b]) f[son[x][b]] = y;
    son[x][b] = y;
    f[y] = x;
    f[x] = z;
    if(z) son[z][y == son[z][1]] = x;
    upd(y);
    upd(x);
}
// Move x to the root with zig / zig-zig / zig-zag steps, then record it as root.
inline void splay(int x)
{
    while(int p = f[x])
    {
        // With a grandparent present: if x and p lean the same way, rotate p
        // first (zig-zig); otherwise rotate x twice (zig-zag).
        if(f[p]) rotate(fnd(x) == fnd(p) ? p : x);
        rotate(x);
    }
    root = x;
}
// Insert one occurrence of x.
// If a node with key x exists, its count is bumped. Otherwise a fresh node
// (index ++tot) is hung as a leaf. Either way that node is splayed to the root.
inline void ins(int x)
{
    if(root == 0)
    {
        root = ++tot;
        val[root] = x;
        cnt[root]++;
        upd(root);
        return;
    }
    int parent = 0;
    int node = root;
    for(;;)
    {
        if(val[node] == x)
        {
            // Duplicate key: bump multiplicity, fix the two local sizes;
            // the rotations in splay() refresh the remaining ancestors.
            cnt[node]++;
            upd(node);
            upd(parent);
            splay(node);
            return;
        }
        parent = node;
        const int dir = val[node] < x;
        node = son[node][dir];
        if(node == 0)
        {
            // Walked off the tree: attach a new leaf in the empty slot.
            const int fresh = ++tot;
            val[fresh] = x;
            cnt[fresh]++;
            f[fresh] = parent;
            son[parent][dir] = fresh;
            upd(fresh);
            upd(parent);
            splay(fresh);
            return;
        }
    }
}
inline int rnk(int x)
{
    int res = 0,cur = root;
    while(cur)
    {
        if(x < val[cur]) { cur = son[cur][0]; continue; }
        res += siz[son[cur][0]];
        if(x == val[cur]) { splay(cur); return res+1; }
        res += cnt[cur];
        cur = son[cur][1]; 
    }
    if(!cur) return n;
}
// Value of the x-th smallest stored element (1-based, counting multiplicities).
// x is clamped to the tree size, so x > size yields the maximum and x <= 0
// yields the minimum. The found node is splayed to the root.
// Returns 0 when the tree is empty.
inline int kth(int x)
{
    // Clamp against the tree's own size rather than the global n, so the
    // result does not depend on main() keeping n in sync.
    x = min(x, siz[root]);
    int cur = root;
    while(cur)
    {
        const int ls = siz[son[cur][0]];
        if(son[cur][0] && x <= ls) cur = son[cur][0];
        else
        {
            x -= cnt[cur] + ls;
            if(x <= 0) { splay(cur); return val[cur]; }
            cur = son[cur][1];
        }
    }
    // Only reachable for an empty tree (given consistent subtree sizes).
    // The original fell off the end of a non-void function here, which is UB.
    return 0;
}
// Splay the rightmost node of the root's left subtree to the root and return
// its index. Equal keys share one node (ins() bumps cnt), so this is the
// largest key smaller than the root's. Returns 0 (no splay) if there is none.
inline int pre()
{
    int cur = son[root][0];
    if(cur)
    {
        while(int r = son[cur][1]) cur = r;
        splay(cur);
    }
    return cur;
}
// Splay the leftmost node of the root's right subtree to the root and return
// its index. Equal keys share one node (ins() bumps cnt), so this is the
// smallest key larger than the root's. Returns 0 (no splay) if there is none.
inline int nxt()
{
    int cur = son[root][1];
    if(cur)
    {
        while(int l = son[cur][0]) cur = l;
        splay(cur);
    }
    return cur;
}
inline void del(int x)
{
    rnk(x);
    if(cnt[root] > 1) { cnt[root]--; upd(root); return; }
    if(!son[root][0] && !son[root][1]) { clear(root); root = 0; return; }
    if(!son[root][0])
    {
        int cur = root;
        root = son[root][1];
        f[root] = 0;
        clear(cur);
        return;
    }
    if(!son[root][1])
    {
        int cur = root;
        root = son[root][0];
        f[root] = 0;
        clear(cur);
        return;
    }
    int cur = root,k = pre();
    f[son[cur][1]] = k;
    son[k][1] = son[cur][1];
    clear(cur);
    upd(root);
}
// Read the initial multiset, then process m online operations.
// Each operand is XORed with the previous query answer (lst). Answers of
// ops 3-6 are XOR-accumulated into ans, which is printed at the end.
int main()
{
//  freopen("P6136_2.in","r",stdin);
    n = read();
    m = read();
    for(int i = 0;i < n;i++)
    {
        x = read();
        ins(x);
    }
    for(int q = 0;q < m;q++)
    {
        opt = read();
        x = read() ^ lst;
        switch(opt)
        {
            case 1: ins(x); n++; break;
            case 2: del(x); n--; break;
            case 3: lst = rnk(x); break;
            case 4: lst = kth(x); break;
            // Predecessor / successor: temporarily insert x so it sits at the
            // root, read its neighbour, then remove x again.
            case 5: ins(x); lst = val[pre()]; del(x); break;
            case 6: ins(x); lst = val[nxt()]; del(x); break;
        }
        if(opt > 2) ans ^= lst;
    }
    printf("%d",ans);
    return 0;
}

|