萌新刚学splay,求助

P6136 【模板】普通平衡树(数据加强版)

Sola_ @ 2021-09-28 17:36:45

样例输出5,点1,4过了,其他WA了

#include<bits/stdc++.h>
using namespace std;
const int N=1100009,INF=0x7fffffff;
int n,m,tmp,rt,id;
long long last,ans;
long long t[N][2];
int fa[N],siz[N],cnt[N],val[N];

int ask_dir(int u){
    return t[fa[u]][1]==u;
}

void connect(int u,int f,int p){
    if(u!=0) fa[u]=f;
    if(f!=0) t[f][p]=u;
}

void update(int u){
    siz[u]=siz[t[u][0]]+siz[t[u][1]]+cnt[u];
}

void rotate(int u){
    int f=fa[u],gf=fa[f];
    int dir=ask_dir(u),fdir=ask_dir(f);
    int ano_son=t[u][!dir];
    connect(ano_son,f,dir);
    connect(u,gf,fdir);
    connect(f,u,!dir);
    update(f);
    update(u);
}

void splay(int u,int end){
    for(int useless;fa[u]!=end;rotate(u))
        if(fa[fa[u]]!=end&&ask_dir(fa[u])==ask_dir(u))
            rotate(fa[u]);
    if(end==0) rt=u;
}

void insert(int x){
    int u=rt;
    if(!rt){
        rt=++id;
        val[id]=x;
        siz[id]=cnt[id]=1;
        return;
    }
    while(val[u]!=x){
        siz[u]++;
        if(x<=val[u]){
            if(t[u][0]==0){
                val[++id]=x;
                connect(id,u,0);
            }
            u=t[u][0];
        }
        else if(x>val[u]){
            if(t[u][1]==0){
                val[++id]=x;
                connect(id,u,1);
            }
            u=t[u][1];
        }
    }
    siz[u]++,cnt[u]++;
    splay(u,0);
}

void ask_sort(int x){
    int u=rt;
    if(!u) return;
    while(t[u][x>val[u]]&&x!=val[u])
        u=t[u][x>val[u]];
    splay(u,0);
}

int ask_pre_nxt(int x,int typ){
    ask_sort(x);
    int u=rt;
    if((val[u]<x&&!typ)||(val[u]>x&&typ)) return u;
    u=t[u][typ];
    while(t[u][typ^1]) u=t[u][typ^1];
    return u;
}

void cut_off(int x){
    int pre=ask_pre_nxt(x,0);
    int nxt=ask_pre_nxt(x,1);
    splay(pre,0);
    splay(nxt,pre);
    int mission=t[nxt][0];
    if(cnt[mission]>1){
        cnt[mission]--;
        splay(mission,0);
    }
    else
        t[nxt][0]=0;
}

int ask_k_val(int x){
    int u=rt;
    if(siz[u]<x){
        while(t[u][1])
            u=t[u][1];
        return u;
    }
    while(1){
        int dir=t[u][0];
        if(x>siz[dir]+cnt[u]){
            x-=siz[dir]+cnt[u];
            u=t[u][1];
        }
        else if(x<=siz[dir])
            u=t[u][0];
        else
            return val[u];
    }
}

int main(){
    ios::sync_with_stdio(false);
    insert(-INF);
    insert(+INF);
    cin>>n>>m;
    for(int i=1;i<=n;i++){
        cin>>tmp;
        insert(tmp);
    }
    for(int i=1;i<=m;i++){
        int opt,x;
        cin>>opt>>x;
        x^=last;
        if(opt==1)
            insert(x);
        else if(opt==2)
            cut_off(x);
        else if(opt==3){
            ask_sort(x);
            last=siz[t[rt][0]];
            ans^=last;
        }
        else if(opt==4){
            last=ask_k_val(x+1);
            ans^=last;
        }
        else if(opt==5){
            last=ask_pre_nxt(x,0);
            ans^=val[last];
        }
        else if(opt==6){
            last=ask_pre_nxt(x,1);
            ans^=val[last];
        }
    }
    cout<<ans<<endl;
    return 0;
} 

|