Splay 5,6操作大问题

P6136 【模板】普通平衡树(数据加强版)

firstlight @ 2023-08-30 12:14:42

操作 5、6（求前驱/后继）中的旋转会相互影响，导致输出结果错误。

// Problem: P6136 【模板】普通平衡树(数据加强版)
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P6136
// Memory Limit: 89 MB
// Time Limit: 3000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
// Node-pool capacity. Every insert consumes a fresh index (idx is never
// recycled), so the pool must hold 2 sentinels + the n initial values + one
// node per inserting operation. Assumes the P6136 limits n <= 1e5,
// m <= 1e6 (confirm against the statement); the old 1000010 was too small.
// Index 0 is reserved as the null node.
const int N = 1100010, INF = 2e9;

int n, m;   // number of initial values, number of operations
int last;   // answer of the most recent query (operations 3-6)
int ans;    // XOR of all query answers
int L, R;   // node indices of the -INF / +INF sentinels

struct Node{
    int s[2], p, v;   // children (0 = none), parent (0 = none), stored value
    int size;         // number of nodes in the subtree rooted here

    // Initializes a freshly allocated leaf; children stay zero because the
    // pool is a zero-initialized global and indices are never reused.
    void init(int _v, int _p)
    {
        v = _v, p = _p;
        size = 1;
    }
}tr[N];
int root, idx;   // current root (0 = empty tree), last allocated pool index

// Recomputes the subtree size of u from its two children. Relies on the
// null index 0 keeping size 0, so missing children contribute nothing.
void pushup(int u)
{
    const int *child = tr[u].s;
    tr[u].size = 1 + tr[child[0]].size + tr[child[1]].size;
}

// Single rotation that lifts x above its parent y, keeping in-order order.
// x's inner child (the one on the side facing away from y) is re-hung under y.
void rotate(int x)
{
    int y = tr[x].p;
    int z = tr[y].p;
    int side = tr[y].s[1] == x;          // 1 if x is y's right child
    int inner = tr[x].s[side ^ 1];       // subtree that switches parents

    tr[z].s[tr[z].s[1] == y] = x;        // x takes y's place under z
    tr[x].p = z;

    tr[y].s[side] = inner;               // inner subtree moves under y
    tr[inner].p = y;

    tr[x].s[side ^ 1] = y;               // y becomes x's child
    tr[y].p = x;

    pushup(y);                           // y is now below x: update it first
    pushup(x);
}

// Rotates x upward until its parent is k; with k == 0, x becomes the root.
// Uses zig-zig (rotate parent first) or zig-zag (rotate x twice) steps.
void splay(int x, int k)
{
    for(int y = tr[x].p; y != k; y = tr[x].p)
    {
        int z = tr[y].p;
        if(z != k)
        {
            bool zigzag = (tr[y].s[1] == x) != (tr[z].s[1] == y);
            rotate(zigzag ? x : y);
        }
        rotate(x);
    }
    if(k == 0) root = x;
}

// Inserts value x as a new leaf (equal values go to the left), splays it to
// the root and returns its node index.
int insert(int x)
{
    int parent = 0;
    for(int cur = root; cur; cur = tr[cur].s[x > tr[cur].v])
        parent = cur;

    int fresh = ++idx;
    if(parent) tr[parent].s[x > tr[parent].v] = fresh;
    tr[fresh].init(x, parent);
    splay(fresh, 0);
    return fresh;
}

int getrank(int v)
{
    int u = root, res = 0;
    while(u)
    {
        if(tr[tr[u].s[0]].v >= v) u = tr[u].s[0];
        else if(tr[u].v >= v)
        {
            res += tr[tr[u].s[0]].size + 1;
            return res;
        }
        else res += tr[tr[u].s[0]].size + 1, u = tr[u].s[1]; 
    }
    return -1;
}

int getk(int k)
{
    int u = root;
    while(u) 
    {
        if(tr[tr[u].s[0]].size >= k) u = tr[u].s[0];
        else if(tr[tr[u].s[0]].size + 1 == k) return u;
        else k -= tr[tr[u].s[0]].size + 1, u = tr[u].s[1];
    }
    return -1;
}

// Removes one occurrence of value k. If k is not stored, the tree is left
// unchanged (previously the first larger value would have been deleted).
// Technique: splay the in-order neighbours l and r around the target so the
// target becomes r's entire left subtree, then detach it.
void clean(int k)
{
    int pos = getrank(k);   // position of the first node with value >= k
    if(pos == -1) return;
    int target = getk(pos);
    if(target == -1 || tr[target].v != k) return;

    int l = getk(pos - 1), r = getk(pos + 1);
    if(l == -1) l = L;
    if(r == -1) r = R;
    splay(l, 0), splay(r, l);
    tr[r].s[0] = 0;          // r's left subtree is exactly the target node
    pushup(r), pushup(l);
}

void print(int u)
{
    if(tr[u].s[0]) print(tr[u].s[0]);
    if(tr[u].v != INF && tr[u].v != -INF) printf("%d ", tr[u].v);
    if(tr[u].s[1]) print(tr[u].s[1]);
}

int dfs_next(int u)
{
    while(tr[u].v == tr[tr[u].s[1]].v)
    {
        u = tr[u].s[1];
        splay(u, 0);
    } 
    return tr[tr[u].s[1]].v;
}

int dfs_pre(int u)
{
    while(tr[u].v == tr[tr[u].s[0]].v)
    {
        u = tr[u].s[0];
        splay(u, 0);
    } 
    return tr[tr[u].s[0]].v;
}

int main()
{
    scanf("%d%d", &n, &m);
    L = insert(-INF), R = insert(INF);
    for(int i = 1; i <= n; i ++ ) 
    {
        int a;
        scanf("%d", &a);
        insert(a);
    }
    int op, k;
    while(m -- )
    {
        // print(root);
        // printf("\n");
        scanf("%d%d", &op, &k);
        // k ^= last;
        // printf("%d\n", k);
        if(op == 1) insert(k);
        else if(op == 2) clean(k);
        else if(op == 3) 
        {
            last = getrank(k) - 1;
            ans ^= last;
            // printf("%d <3<%d>\n", last, k);
        }
        else if(op == 4)
        {
            last = tr[getk(k + 1)].v;
            ans ^= last;
            // printf("%d <4<%d>\n", last, k);
        }
        else if(op == 5)
        {
            insert(k);
            int rank = getrank(k);
            splay(getk(rank + 1), 0);
            // printf("%d\n",root);
            last = dfs_pre(getk(rank + 1));
            ans ^= last;
            // printf("%d <5<%d>\n", last, k);
            clean(k);
        }
        else 
        {
            insert(k);
            int rank = getrank(k);
            splay(getk(rank - 1), 0);
            // printf("%d\n",root);
            last = dfs_next(getk(rank - 1));
            ans ^= last;
            // printf("%d <6<%d>\n", last, k);
            clean(k);
        }
    }
    // printf("%d", ans);
    return 0;
}

|