Treap 40pts求助

P6136 【模板】普通平衡树(数据加强版)

Mirasycle @ 2022-08-18 13:45:59

感觉操作4出了问题

#include<iostream>
#include<cstring>
#include<algorithm>
#include<vector>
#include<cstdio>
using namespace std;
const int maxn=1e5+10;
const int maxm=2e6+10;
const int inf=0x3f3f3f3f;
int root,last=0,ans=0;
struct node{
    int l,r;
    int rank,val;
    int cnt,size;
}a[maxm];
struct Treap{
    int tot;
    int newnode(int v){
        a[++tot].val=v;
        a[tot].rank=rand();
        a[tot].cnt=a[tot].size=1;
        return tot;
    }
    void update(int u){
        a[u].size=a[a[u].l].size+a[a[u].r].size+a[u].cnt;
    }
    void build_tree(){
        tot=0; root=1;
        newnode(-inf);
        newnode(inf);
        a[1].r=2;
        update(1);
    }
    void rotate_left(int &p){//左旋 
        int q=a[p].r;
        a[p].r=a[q].l;
        a[q].l=p;
        p=q;
        update(a[p].l); update(p);
    }
    void rotate_right(int &p){//右旋 
        int q=a[p].l;
        a[p].l=a[q].r;
        a[q].r=p;
        p=q;
        update(a[p].r); update(p);
    }
    void insert(int &p,int v){
        if(p==0){
            p=newnode(v);
            return ;
        }
        if(v<a[p].val){
            insert(a[p].l,v);
            if(a[a[p].l].rank>a[p].rank) rotate_right(p); 
        }else if(v==a[p].val){
            a[p].cnt++;
        }else{
            insert(a[p].r,v);
            if(a[a[p].r].rank>a[p].rank) rotate_left(p);
        }
        update(p);
        return ;
    }
    void remove(int &p,int v){
        if(p==0) return ;
        if(v<a[p].val){
            remove(a[p].l,v);
        }else if(v==a[p].val){
            if(a[p].cnt>1){
                a[p].cnt--;
            }else if(a[p].l||a[p].r){
                if(a[p].r==0||a[a[p].r].rank<a[a[p].l].rank){
                    rotate_right(p);
                    remove(a[p].r,v);
                }else{
                    rotate_left(p);
                    remove(a[p].l,v);
                }
            }else{
                p=0;
            }
        }else{
            remove(a[p].r,v);
        }
        update(p);
    }
    int getrank(int p,int v){
        if(p==0) return 1;
        if(v==a[p].val) return a[a[p].l].size+1;
        if(v<a[p].val) return getrank(a[p].l,v);
        return getrank(a[p].r,v)+a[a[p].l].size+a[p].cnt;       
    }
    int getval(int p,int ra){
        if(p==0) return 0;
        if(ra<=a[a[p].l].size) return getval(a[p].l,ra);
        if(a[a[p].l].size+a[p].cnt>=ra) return a[p].val;
        return getval(a[p].r,ra-a[a[p].l].size-a[p].cnt);
    }
    int getpre(int v){
        int ans=1;
        int p=root;
        while(p){
            if(v==a[p].val){
                if(a[p].l>0){
                    p=a[p].l;
                    while(a[p].r>0) p=a[p].r; 
                    ans=p;
                }
                break;
            }
            if(a[p].val<v&&a[p].val>a[ans].val) ans=p;
            p=v<a[p].val?a[p].l:a[p].r;
        }
        return a[ans].val;
    }
    int getnext(int v){
        int ans=2;
        int p=root;
        while(p){
            if(v==a[p].val){
                if(a[p].r>0){
                    p=a[p].r;
                    while(a[p].l>0) p=a[p].l; 
                    ans=p;
                }
                break;
            }
            if(a[p].val>v&&a[p].val<a[ans].val) ans=p;
            p = v<a[p].val?a[p].l:a[p].r;
        }
        return a[ans].val;
    }
}tree;
void work(int op,int x){
    switch(op){
        case 1:
            tree.insert(root,x);
            break;
        case 2:
            tree.remove(root,x); 
            break;
        case 3:
            last=tree.getrank(root,x)-1;
            break;
        case 4:
            last= tree.getval(root,x+1);
            break;
        case 5:
            last=tree.getpre(x);
            break;
        case 6:
            last=tree.getnext(x);
            break;
    }
    return ;
}
int main(){
    tree.build_tree();
    int n,m;
    cin>>n>>m;
    for(register int i=1;i<=n;i++){
        int x;
        scanf("%d",&x);
        tree.insert(root,x);
    }
    for(register int i=1;i<=m;i++){
        int op,x;
        scanf("%d%d",&op,&x);
        x=x^last;
        work(op,x);
        if(op!=1&&op!=2) ans=ans^last;
        //cout<<"ans:"<<ans<<endl;
        //cout<<"last:"<<last<<endl;
        //cout<<"x:"<<x<<endl; 
    }
    cout<<ans<<endl;
    return 0;
}

by Mirasycle @ 2022-08-19 09:55:54

已AC,此贴完结 顺便警告一下inf取0x3f3f3f3f和1e9都不行,必须INT_MAX


|