求助 splay

P6136 【模板】普通平衡树(数据加强版)

wyx__ @ 2021-02-23 17:14:49

全都MLE+RE

自己测试了第一个点，输出也没有问题；感觉也没有哪里会 MLE 的样子，而且 P3369 也过了。

代码:

#include<bits/stdc++.h>
using namespace std;
// Splay-tree node storage, indexed by node id (node 0 is the "null" sentinel and
// is never allocated; ids 1..tot are handed out by insert and never reused).
// 1100005 slots presumably covers n initial values plus m insert operations — confirm.
// NOTE(review): with `using namespace std` under C++17, the global name `size`
// can clash with std::size; renaming it would require touching every user.
int val[1100005];   // key stored at the node
int lc[1100005];    // left child id (0 = none)
int rc[1100005];    // right child id (0 = none)
int fa[1100005];    // parent id (0 = root / detached)
int size[1100005];  // number of stored elements (with multiplicity) in the subtree
int cnt[1100005];   // multiplicity of val at this node
int rt;             // current root id (0 = empty tree)
int tot;            // number of node ids allocated so far
int n, m;           // initial element count, operation count
int ans;            // XOR of all query answers (printed at the end)
int lastans;        // answer of the most recent query (used to decode input)
// Recompute x's subtree element count: both children's counts plus x's own multiplicity.
void update(int x){
    const int childTotal=size[lc[x]]+size[rc[x]];
    size[x]=childTotal+cnt[x];
}
// Single rotation lifting x above its parent y, preserving in-order order.
// x's inner child b (the one between x and y in key order) is handed over to y,
// x takes y's former slot under the grandparent z (if any), then the sizes of
// y (now below) and x (now above) are recomputed in that order.
void rotate(int x) {
    const int y=fa[x];
    const int z=fa[y];
    const bool xIsLeft=(lc[y]==x);
    const int b=xIsLeft?rc[x]:lc[x];

    // Re-attach x where y used to hang under z.
    fa[x]=z;
    if(z){
        if(lc[z]==y)lc[z]=x;
        else rc[z]=x;
    }

    // y becomes x's child; b moves across to y.
    if(xIsLeft){
        lc[y]=b;
        rc[x]=y;
    }else{
        rc[y]=b;
        lc[x]=y;
    }
    if(b)fa[b]=y;
    fa[y]=x;

    update(y);
    update(x);
}
// True when x is the right child of its parent (compares against rc[0] when x has no parent).
bool wrt(int x){
    const int parent=fa[x];
    return rc[parent]==x;
}
// Rotate x upward until its parent is `target`, using zig-zig / zig-zag double
// steps whenever the grandparent is not `target`. Splaying to 0 makes x the root.
void splay(int x,int target){
    for(int p;(p=fa[x])!=target;rotate(x)){
        // Double step: rotate the parent first when x and p lean the same way (zig-zig),
        // otherwise rotate x itself first (zig-zag).
        if(fa[p]!=target)rotate(wrt(x)==wrt(p)?p:x);
    }
    if(!target)rt=x;
}
// Locate the node holding value v. On success it is splayed to the root and its id
// returned; if v is not stored, returns 0 and leaves the tree shape unchanged.
int find(int v){
    int cur=rt;
    while(cur&&val[cur]!=v)
        cur=(v<val[cur])?lc[cur]:rc[cur];
    if(cur)splay(cur,0);
    return cur;
}
// Insert one copy of v. Either v's existing node gets its multiplicity bumped, or a
// fresh leaf is allocated. In both cases that node is splayed to the root and its id
// is returned.
//
// The original was declared `int` but had no return statement. Flowing off the end of
// a value-returning function is undefined behaviour, and with -O2 GCC may treat that
// path as unreachable and emit code that runs into whatever follows. That is a likely
// cause of the MLE/RE verdicts. Existing callers ignore the result, so returning the
// node id is backward-compatible.
int insert(int v){
    int x=rt,y=0,dir=0;
    while(x){
        y=x;
        size[x]++;  // every node on the search path gains one element in its subtree
        if(v==val[x])break;
        if(val[x]>v)x=lc[x],dir=0;
        else x=rc[x],dir=1;
    }
    if(x){
        cnt[x]++;   // duplicate: its size was already bumped in the loop above
    }else{
        x=++tot;    // ids are never reused, so this node's arrays are still zero
        fa[x]=y,size[x]=1,cnt[x]=1,val[x]=v;
        if(y)(dir==0?lc[y]:rc[y])=x;
    }
    // Splay in both cases: skipping it on duplicates (as before) lets repeated
    // inserts walk the same deep path and breaks splay's amortized O(log n) bound.
    splay(x,0);
    return x;
}
// Merge two detached subtrees rooted at x and y, where every key under x is smaller
// than every key under y. The maximum of x's tree is splayed to the root (it then has
// no right child) and y is hung as its right child.
void join(int x,int y){
    fa[x]=0;
    fa[y]=0;
    int maxNode=x;
    for(;rc[maxNode];maxNode=rc[maxNode]);
    splay(maxNode,0);
    fa[y]=maxNode;
    rc[maxNode]=y;
    update(maxNode);
}
// Remove one copy of the value stored at node x (as returned by find). When the
// multiplicity drops to zero, the node is unlinked and its two subtrees are joined.
// x == 0 (value not present) is a no-op. Previously it decremented cnt[0]/size[0]
// to -1, silently corrupting every later size computation.
void Delete(int x){
    if(!x)return;
    // Make x the root so that decrementing only its own size keeps all sizes
    // consistent (previously this relied implicitly on find() having splayed x).
    splay(x,0);
    cnt[x]--,size[x]--;
    if(cnt[x])return;
    if(!lc[x]||!rc[x]){
        rt=lc[x]+rc[x];       // the single remaining child (or 0 if the tree is now empty)
        if(rt)fa[rt]=0;
    }else{
        join(lc[x],rc[x]);    // join sets rt to the new root
    }
    lc[x]=rc[x]=0;            // fully detach the dead node
}
int val_to_rank(int v){
    int x=rt,ans=1;
    while(x){
        if(val[x]==v){
            ans+=size[lc[x]];
            splay(x,0);
            break;
        }
        if(v<val[x])x=lc[x];
        else {
            ans+=size[lc[x]]+cnt[x];
            x=rc[x];
        }
    }
    return ans;
}
// Value of the v-th smallest element (1-based, counting multiplicity). The node found
// is splayed to the root. v must lie in [1, size[rt]]; otherwise the walk reaches
// node 0 and never terminates.
int rank_to_val(int v){
    int cur=rt;
    for(;;){
        const int below=size[lc[cur]];               // elements in the left subtree
        const int through=size[cur]-size[rc[cur]];   // left subtree plus copies at cur
        if(v<=below){
            cur=lc[cur];
            continue;
        }
        if(v<=through)break;
        v-=through;
        cur=rc[cur];
    }
    splay(cur,0);
    return val[cur];
}
int lower(int v){
    int x=rt,res=-1e9;
    while(x){
        if(val[x]<v&&val[x]>res)res=val[x];
        if(v>val[x])x=rc[x];
        else x=lc[x];
    }
    return res;
}
int upper(int v){
    int x=rt,res=1e9;
    while(x){
        if(val[x]>v&&val[x]<res)res=val[x];
        if(v<val[x])x=lc[x];
        else x=rc[x];
    }
    return res;
}
// Dump every node id ever allocated (1..tot), including nodes already unlinked by
// Delete, one line per node with all of its fields.
void debug(){
    int i=1;
    while(i<=tot){
        cout<<"node:"<<i<<" val:"<<val[i]<<" fa:"<<fa[i]<<" lc:"<<lc[i]<<" rc:"<<rc[i]<<" cnt:"<<cnt[i]<<" size:"<<size[i]<<endl;
        ++i;
    }
}
int main() {
    // For local testing: redirect stdin/stdout with freopen("P6136_1.in","r",stdin) etc.
    // Input: n initial values, then m operations "opt x" whose x is XOR-ed with the
    // answer of the previous query (opt 3..6). Output: XOR of all query answers.
    cin>>n>>m;
    for(int i=0;i<n;i++){
        int value;
        scanf("%d",&value);
        insert(value);
    }
    for(int q=0;q<m;q++){
        int opt,x;
        scanf("%d%d",&opt,&x);
        x^=lastans;
        switch(opt){
            case 1: insert(x); break;
            case 2: Delete(find(x)); break;
            case 3: lastans=val_to_rank(x); break;
            case 4: lastans=rank_to_val(x); break;
            case 5: lastans=lower(x); break;
            case 6: lastans=upper(x); break;
        }
        if(opt>2)ans^=lastans;
    }
    cout<<ans;
}

|