求助,splay,re8个点

P6136 【模板】普通平衡树(数据加强版)

我爱杨帆 @ 2021-01-31 09:57:57

#include<bits/stdc++.h>
#define int long long
#define re register
#define inf 1e18
using namespace std;
const int sz=2e7+5;
inline int read()
{
    char ch;
    int r=0,f=1;
    ch=getchar();
    while((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
    if(ch=='-') f=-1,ch=getchar();
    while(ch>='0'&&ch<='9') r=r*10+ch-'0',ch=getchar();
    return r*f;
}
int root,tot;
struct BST
{       
    struct node
    {
        int val,ch[5],ff,size,cnt;
    }t[sz];
    int identify(int x)
    {   
        return t[t[x].ff].ch[1]==x ? 1:0;
    }   
    void update(int x)
    {   
        t[x].size=t[t[x].ch[0]].size+t[t[x].ch[1]].size+t[x].cnt;
    }   
    void rotate(int x)
    {   
        int y=t[x].ff,z=t[t[x].ff].ff,kx=identify(x),ky=identify(y);
        t[x].ff=z,t[z].ch[ky]=x;
        t[y].ff=x;
        t[y].ch[kx]=t[x].ch[kx^1],t[t[x].ch[kx^1]].ff=y;t[x].ch[kx^1]=y;
        update(y);update(x);
    }   
    void splay(int x,int goal)
    {
        while(t[x].ff!=goal)        
        {           
            int y=t[x].ff,z=t[t[x].ff].ff;
            if(z!=goal)     
                (t[z].ch[0]==y)^(t[y].ch[0]==x) ? rotate(x):rotate(y);
            rotate(x);      
        }       
        if(!goal)       root=x; 
    }
    void find(int val)
    {   
        int u=root;     
        if(!u) return;  
        while(t[u].ch[t[u].val<val]&&val!=t[u].val) u=t[u].ch[t[u].val<val];
        splay(u,0);
    }   
    void cre(int val,int ff)
    {
        int u=++tot;
        t[ff].ch[t[ff].val<val]=u;
        t[u].ff=ff;
        t[u].size=t[u].cnt=1;
        t[u].val=val;
    }
    void insert(int val)
    {   
        int u=root,ff=0;
        while(t[u].val!=val&&u)
            ff=u,u=t[u].ch[t[u].val<val];
        if(u)   t[u].cnt++; 
        else 
        {
            u=++tot;
            if(ff) 
                t[ff].ch[val>t[ff].val]=u;
            t[u].ch[0]=t[u].ch[1]=0;
            t[u].ff=ff,t[tot].val=val;
            t[u].size=t[u].cnt=1;   
        }
        splay(u,0); 
    }   
    int getnext(int x,int fla)
    {   
        find(x);
        int u=root;
        if(t[u].val>x&&fla) return u;
        if(t[u].val<x&&!fla) return u;
        u=t[u].ch[fla];
        while(t[u].ch[fla^1]) u=t[u].ch[fla^1];
        return u; 
    }   
    int dele(int x)
    {               
        int last=getnext(x,0);
        int next=getnext(x,1);
        splay(last,0);
        splay(next,last);
        if(t[t[next].ch[0]].cnt>1) 
        {
            t[t[next].ch[0]].cnt--;
            splay(t[next].ch[0],0);
        }
        else t[next].ch[0]=0;
    }               
    int getrankbyval(int val,int now)
    {               
        if(!now) return 1;  
        if(val==t[now].val) return t[t[now].ch[0]].size+1;
        if(val<t[now].val) return getrankbyval(val,t[now].ch[0]);
        return t[t[now].ch[0]].size+t[now].cnt+getrankbyval(val,t[now].ch[1]);
    }               
    int getvalbyrank(int ran,int now)
    {   
        if(!now) return inf;
        if(t[t[now].ch[0]].size>=ran) return getvalbyrank(ran,t[now].ch[0]);
        if(t[t[now].ch[0]].size+t[now].cnt>=ran) return t[now].val;
        return getvalbyrank(ran-t[t[now].ch[0]].size-t[now].cnt,t[now].ch[1]);  
    }   
}S;     
int ans=0,las=0;
signed main()
{
    int n=read(),m=read();
    S.insert(inf);S.insert(-inf);
    for(int i=1;i<=n;i++) 
        S.insert(read());   
    for(int i=1;i<=m;i++)
    {
        int op=read(),x=read()^las;
        switch (op) 
        {
            case 1: 
                S.insert(x);
                break;
            case 2:
                S.dele(x);
                break;
            case 3:
                las=S.getrankbyval(x,root)-1;
                ans^=las;
                break;
            case 4:
                las=S.getvalbyrank(x+1,root);
                ans^=las;
                break;
            case 5:
                las=S.t[S.getnext(x,0)].val;
                ans^=las;
                break;
            case 6:
                las=S.t[S.getnext(x,1)].val;
                ans^=las;
                break;              
        }
    }
    cout<<ans<<endl;
}
/*
5 1
1345345 4345345 4345345 55342 923454353454353
3 8
*/

|