替罪羊树 WA96pts ,玄关,码风正常

P6136 【模板】普通平衡树(数据加强版)

_JoeyJ_ @ 2023-07-14 20:13:44

#include<bits/stdc++.h>
using namespace std;
#define int long long

const int N=2e6+9;
const double A=0.75;

class goat{
    public:
        int root;
        int pcnt;
    private:
        int lc[N],rc[N];
        int siz[N],tot[N],del[N];
        int w[N],cnt[N];

        void calc_size(int x){
            siz[x]=siz[lc[x]]+siz[rc[x]]+1;
            tot[x]=tot[lc[x]]+tot[rc[x]]+cnt[x];
            del[x]=del[lc[x]]+del[rc[x]]+bool(x);
        }

        int rft[N];
        bool rebuild_check(int x){
            return cnt[x]&&(max(siz[lc[x]],siz[rc[x]])>=A*siz[x]||del[x]<=A*siz[x]);
        }
        void rebuild_flat(int x,int &c){
            if(!x) return ;
            this->rebuild_flat(lc[x],c);
            if(cnt[x]) rft[++c]=x;
            this->rebuild_flat(rc[x],c);
        }
        int rebuild_build(int l,int r){
            if(l>=r) return 0;
            int mid=(l+r)/2;
            int x=rft[mid];
            lc[x]=this->rebuild_build(l,mid);
            rc[x]=this->rebuild_build(mid+1,r);
            this->calc_size(x);
            return x;
        }
        void rebuild(int &x){
            int sum=0;
            this->rebuild_flat(x,sum);
            x=this->rebuild_build(1,sum+1);
        }

        int lower_bound(int x,int k){
            if(!x) return 0;
            if(k==w[x]) return tot[lc[x]];
            else if(k>w[x]) return tot[lc[x]]+cnt[x]+lower_bound(rc[x],k);
            else return lower_bound(lc[x],k);
        }

        int upper_bound(int x,int k){
            if(!x) return 1;
            if(k==w[x]) return tot[lc[x]]+cnt[x]+1;
            else if(k>w[x]) return tot[lc[x]]+cnt[x]+upper_bound(rc[x],k);
            else return upper_bound(lc[x],k);
        }

    public:
        void insert_(int &x,int k){
            if(!x){
                x=++pcnt;
                w[x]=k;
                if(!root) root=x;
                lc[x]=rc[x]=0;
                cnt[x]=siz[x]=del[x]=tot[x]=1;
                return ;
            }
            if(w[x]==k) cnt[x]++;
            else if(w[x]<k) insert_(rc[x],k);
            else insert_(lc[x],k);
            this->calc_size(x);
            if(this->rebuild_check(x)) this->rebuild(x);
        }

        void delete_(int &x,int k){
            if(!x) return ;
            if(w[x]==k) cnt[x]--;
            else if(w[x]<k) delete_(rc[x],k);
            else delete_(lc[x],k);
            this->calc_size(x);
            if(this->rebuild_check(x)) this->rebuild(x);
        }

        int rank_(int x,int k){
            if(!x) return 1;
            if(w[x]==k) return tot[lc[x]]+1;
            else if(w[x]<k) return tot[lc[x]]+cnt[x]+rank_(rc[x],k);
            else return rank_(lc[x],k);
        }

        int at_(int x,int k){
            if(!x) return 0;
            if(tot[lc[x]]<k&&k<=tot[lc[x]]+cnt[x]) return w[x];
            else if(k>tot[lc[x]]+cnt[x]) return at_(rc[x],k-tot[lc[x]]-cnt[x]);
            else return at_(lc[x],k);
        }

        inline int predecessor_(int x,int k){
//          cout<<lower_bound(x,k)<<endl;
            return at_(x,this->lower_bound(x,k));
        }

        inline int successor_(int x,int k){
//          cout<<upper_bound(x,k)<<endl;
            return at_(x,this->upper_bound(x,k));
        }

        void print(){
            cout<<"root : "<<root<<endl;
            cout<<"total : "<<pcnt<<endl;
            for(int i=1;i<=pcnt;i++) cout<<w[i]<<' ';cout<<endl;
            for(int i=1;i<=pcnt;i++) cout<<lc[i]<<' ';cout<<endl;
            for(int i=1;i<=pcnt;i++) cout<<rc[i]<<' ';cout<<endl;
            for(int i=1;i<=pcnt;i++) cout<<cnt[i]<<' ';cout<<endl;
            for(int i=1;i<=pcnt;i++) cout<<tot[i]<<' ';cout<<endl;
            for(int i=1;i<=pcnt;i++) cout<<siz[i]<<' ';cout<<endl;
            for(int i=1;i<=pcnt;i++) cout<<del[i]<<' ';cout<<endl;
        }
};
goat tr;
signed main(){
    int n,m;
    cin>>n>>m;
    for(int i=1;i<=n;i++){
        int x;
        cin>>x;
        tr.insert_(tr.root,x);
    }
    int l=0,sm=0;
    for(int i=1;i<=m;i++){
        int op,x;
        cin>>op>>x;
        x^=l;
        int ans=0;
        if(op==1) tr.insert_(tr.root,x);
        if(op==2) tr.delete_(tr.root,x);
        if(op==3) ans=tr.rank_(tr.root,x);
        if(op==4) ans=tr.at_(tr.root,x);
        if(op==5) ans=tr.predecessor_(tr.root,x);
        if(op==6) ans=tr.successor_(tr.root,x);
        if(ans) l=ans;
        sm^=ans;
//      tr.print();
    }
    cout<<sm<<endl;
}

// 40:18.27

WA on #17,18


|