Splay 30pts求助

P6136 【模板】普通平衡树(数据加强版)

火羽白日生 @ 2021-04-18 15:35:10

rt,普通平衡树小改了一下连样例都没过

#include <bits/stdc++.h>
#define LL long long
#define ull unsigned long long
#define rint register int

using namespace std;

inline int read(){
    int w=0,f=1; char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-') f=-1; ch=getchar();}
    while(ch>='0'&&ch<='9'){w=(w<<3)+(w<<1)+(ch^48); ch=getchar();}
    return w*f;
}

const int maxn=1e5+5,maxm=1e6+5,inf=0x7fffffff;

int n,m,root,tot;
int a[maxn];
struct node{
    int son[2],fa,val,cnt,size;
}t[maxn+maxm];
inline int get(int x){
    return x==t[t[x].fa].son[1];
}
inline void pushup(int x){
    t[x].size=t[t[x].son[0]].size+t[t[x].son[1]].size+t[x].cnt;
}
inline void rotate(int x){
    int y=t[x].fa,z=t[y].fa,xpos=get(x),ypos=get(y);
    t[z].son[ypos]=x;
    t[x].fa=z;
    t[y].son[xpos]=t[x].son[xpos^1];
    t[t[x].son[xpos^1]].fa=y;
    t[x].son[xpos^1]=y;
    t[y].fa=x;
    pushup(y); pushup(x);
}
inline void splay(int x,int goal){
    while(t[x].fa!=goal){
        int y=t[x].fa,z=t[y].fa,xpos=get(x),ypos=get(y);
        if(z!=goal){
            if(xpos==ypos) rotate(y);
            else rotate(x);
        }
        rotate(x);
    }
    if(goal==0) root=x;
}
inline void insert(int x){
    int u=root,fa=0;
    while(u && t[u].val!=x){
        fa=u;
        u=t[u].son[x>t[u].val];
    }
    if(u) t[u].cnt++;
    else{
        u=++tot;
        if(fa) t[fa].son[x>t[fa].val]=u;
        t[u].son[0]=t[u].son[1]=0;
        t[u].fa=fa;
        t[u].val=x;
        t[u].cnt=t[u].size=1;
    }
    splay(u,0);
}
inline void find(int x){
    int u=root;
    if(!u) return;
    while(t[u].son[x>t[u].val] && x!=t[u].val)
        u=t[u].son[x>t[u].val];
    splay(u,0);
}
int pre_nxt(int x,int op){
    find(x);
    int u=root;
    if(t[u].val>x && op) return u;
    if(t[u].val<x && !op) return u;
    u=t[u].son[op];
    while(t[u].son[op^1]) u=t[u].son[op^1];
    return u;
}
inline void del(int x){
    int pre=pre_nxt(x,0),nxt=pre_nxt(x,1);
    splay(pre,0);
    splay(nxt,pre);
    int pos=t[nxt].son[0];
    if(t[pos].cnt>1){
        t[pos].cnt--;
        splay(pos,0);
    }
    else t[nxt].son[0]=0;
}
int kth(int x){
    int u=root;
    if(t[u].size<x) return 0;
    while(1){
        int lson=t[u].son[0];
        if(x>t[lson].size+t[u].cnt){
            x-=t[lson].size+t[u].cnt;
            u=t[u].son[1];
        }
        else if(t[lson].size>=x) u=lson;
        else return t[u].val; 
    }
}

int main(){
    n=read(); m=read();
    insert(-inf); insert(inf);
    for(int i=1;i<=n;i++) insert(read());
    int last=0,ans=0;
    while(m--){
        int op=read(),x=read()^last;
        if(op==1) insert(x);
        if(op==2) del(x);
        if(op==3){
            find(x);
            last=t[t[root].son[0]].size;
            ans^=last;
        }
        if(op==4){
            last=kth(x+1);
            ans^=last;
        }
        if(op==5){
            last=t[pre_nxt(x,0)].val;
            ans^=last;
        }
        if(op==6){
            last=t[pre_nxt(x,1)].val;
            ans^=last;
        }
    }
    printf("%d\n",ans);
    return 0;
}

|