splay 68pts,T了五个点,想问下哪里写假了(码风清晰)

P6136 【模板】普通平衡树(数据加强版)

eastcloud @ 2022-10-09 17:25:30

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<algorithm>
#include<set>
using namespace std;
int fa[4000001];
int ch[4000001][2];
int val[4000001],cnt[4000001];
int siz[4000001];
int root,tot;
inline int read(){
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-')
            f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        x=(x<<1)+(x<<3)+(ch^48);
        ch=getchar();
    }
    return x*f;
}
inline void clear(int x){
    siz[x]=val[x]=cnt[x]=fa[x]=ch[x][0]=ch[x][1]=0;
}
inline bool get_son(int x){
    return ch[fa[x]][1]==x;
}
inline void update(int x){
    siz[x]=cnt[x];
    if(ch[x][0]) siz[x]+=siz[ch[x][0]];
    if(ch[x][1]) siz[x]+=siz[ch[x][1]];
}
inline void rotate(int x){
    int f=fa[x],gra=fa[fa[x]],k=get_son(x);
    ch[f][k]=ch[x][k^1];fa[ch[x][k^1]]=f;
    ch[x][k^1]=f;fa[f]=x;
    fa[x]=gra;
    if(gra) ch[gra][ch[gra][1]==f]=x;
    update(f);update(x);
}
inline void splay(int x){
    for(int f;f=fa[x];rotate(x)){
        if(fa[f]) rotate(get_son(x)==get_son(f)?f:x);
    }
    root=x;
}
inline int New(int v){
    val[++tot]=v;
    siz[tot]=cnt[tot]=1;
    return tot;
}
inline void insert(int v){
    if(!root){
        root=New(v);
        return;
    }
    int now=root,f=0;
    while(1){
        if(val[now]==v){
            cnt[now]++;
            update(now);
            update(f);
            splay(now);
            return;
        }
        f=now;now=ch[now][v>val[now]];
        if(!now){
            int x=New(v);
            fa[x]=f;
            ch[f][v>val[f]]=x;
            splay(x);
            return;
        }
    }
}
inline int find_rank(int x){
    int now=root,rank=0;
    while(1){
        if(val[now]==x){
            splay(now);
            return siz[ch[root][0]]+rank+1;
        }
        else{
            if(val[now]<x){
                rank+=cnt[now];
                if(ch[now][0])rank+=siz[ch[now][0]];
                now=ch[now][1];
            }
            else now=ch[now][0];
        }
    }
}
inline int find_num(int x){
    int now=root;
    while(1){
        if(ch[now][0] && x<=siz[ch[now][0]]) now=ch[now][0];
        else{
            int tmp=cnt[now]+(ch[now][0]?siz[ch[now][0]]:0);
            if(x<=tmp){
                splay(now);
                return val[root];
            }
            x-=tmp;
            now=ch[now][1];
        }
    }
}
inline int pre(){
    int now=ch[root][0];
    while(ch[now][1])now=ch[now][1];
    splay(now);
    return root;
}
inline int nxt(){
    int now=ch[root][1];
    while(ch[now][0]) now=ch[now][0];
    splay(now);
    return root;
}
inline void del(int x){
    int t=find_rank(x);
    if(cnt[root]>1){
        cnt[root]--;
        update(root);
        return;
    }
    if(!ch[root][0] && !ch[root][1]){
        clear(root);
        root=0;
        return;
    }
    if(!ch[root][0]){
        int tmp=root;
        root=ch[root][1];
        fa[root]=0;
        clear(tmp);
        return;
    }
    if(!ch[root][1]){
        int tmp=root;
        root=ch[root][0];
        fa[root]=0;
        clear(tmp);
        return;
    }
    int tmp=root,pre_t=pre();
    fa[ch[tmp][1]]=root;
    ch[root][1]=ch[tmp][1];
    clear(tmp);
    update(root);
}
int main(){
    int ans=0,n,m,opt,last=0,x;
    n=read();m=read();
    for(int i=1;i<=n;i++){
        x=read();
        insert(x);
    }
    for(int i=1;i<=m;i++){
        opt=read();x=read();
        x=x^last;
        if(opt==1) insert(x);
        else if(opt==2) del(x);
        else if(opt==3) {
            insert(x);
            if(cnt[root]==1){
                int tmp=pre();
                last=find_rank(val[tmp])+cnt[tmp];
                del(x);
            }
            else{
                del(x);
                last=find_rank(x);
            }
            ans=ans^last;
        }
        else if(opt==4){
            last=find_num(x);
            ans=ans^last;
        }
        else if(opt==5){
            insert(x);
            last=val[pre()];
            ans=ans^last;
            del(x);
        }
        else if(opt==6){
            insert(x);
            last=val[nxt()];
            ans=ans^last;
            del(x);
        }
    }
    cout<<ans<<endl;
}

|