splay求调

P6136 【模板】普通平衡树(数据加强版)

Lovely_CCCyh___ @ 2024-04-16 17:37:15

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N=1101000;
struct splay_tree
{
    int ff,cnt,ch[2],val,size;
} t[N];
int root,tot;
void update(int x)
{
    t[x].size=t[t[x].ch[0]].size+t[t[x].ch[1]].size+t[x].cnt;
}
void rotate(int x)
{
    int y=t[x].ff;
    int z=t[y].ff;
    int k=(t[y].ch[1]==x);
    t[z].ch[(t[z].ch[1]==y)]=x;
    t[x].ff=z;
    t[y].ch[k]=t[x].ch[k^1];
    t[t[x].ch[k^1]].ff=y;
    t[x].ch[k^1]=y;
    t[y].ff=x;
    update(y);update(x);
}
void splay(int x,int s)
{
    while(t[x].ff!=s)
    {
        int y=t[x].ff,z=t[y].ff;
        if (z!=s)
            (t[z].ch[0]==y)^(t[y].ch[0]==x)?rotate(x):rotate(y);
        rotate(x);
    }
    if (s==0)
        root=x;
}
void find(int x)
{
    int u=root;
    if (!u)
        return ;
    while(t[u].ch[x>t[u].val] && x!=t[u].val)
        u=t[u].ch[x>t[u].val];
    splay(u,0);
}
void insert(int x)
{
    int u=root,ff=0;
    while(u && t[u].val!=x)
    {
        ff=u;
        u=t[u].ch[x>t[u].val];
    }
    if (u)
        t[u].cnt++;
    else
    {
        u=++tot;
        if (ff)
            t[ff].ch[x>t[ff].val]=u;
        t[u].ch[0]=t[u].ch[1]=0;
        t[tot].ff=ff;
        t[tot].val=x;
        t[tot].cnt=1;
        t[tot].size=1;
    }
    splay(u,0);
}
int Next(int x,int f)
{
    find(x);
    int u=root;
    if (t[u].val>x && f)
        return u;
    if (t[u].val<x && !f)
        return u;
    u=t[u].ch[f];
    while(t[u].ch[f^1])
        u=t[u].ch[f^1];
    return u;
}
void Delete(int x)
{
    int last=Next(x,0);
    int Net=Next(x,1);
    splay(last,0);
    splay(Net,last);
    int del=t[Net].ch[0];
    if (t[del].cnt>1)
    {
        t[del].cnt--;
        splay(del,0);
    }
    else
        t[Net].ch[0]=0;
}
int kth(int x)
{
    int u=root;
    while(t[u].size<x)
        return 0;
    while(1)
    {
        int y=t[u].ch[0];
        if (x>t[y].size+t[u].cnt)
        {
            x-=t[y].size+t[u].cnt;
            u=t[u].ch[1];
        }
        else if (t[y].size>=x)
            u=y;
        else
            return t[u].val;
    }
}
signed main()
{
    int n,m,ans=0,last=0,s=0;
    scanf("%lld%lld",&n,&m);
    insert(1e9);
    insert(-1e9);
    for(int i=1;i<=n;i++)
    {
        int x;
        scanf("%lld",&x);
        insert(x);
    }
    while(m--)
    {
        int opt,x;
        scanf("%lld%lld",&opt,&x);
        x=x^last;
        if (opt==1)
            insert(x);
        if (opt==2)
            Delete(x);
        if (opt==3)
        {
            insert(x);
            find(x);
            last=t[t[root].ch[0]].size;
             s^=last;
            Delete(x);
        }
        if (opt==4)
        {
            last=kth(x+1);
             s^=last;
        }
        if (opt==5)
        {
            last=t[Next(x,0)].val;
             s^=last;
        }
        if (opt==6)
        {
            last=t[Next(x,1)].val;
             s^=last;
        }
    }
    printf("%lld",s);
    return 0;
}

by kele7 @ 2024-04-16 19:37:46

倦疲


|