Splay 32pts 求调

P6136 【模板】普通平衡树(数据加强版)

Peaky @ 2024-08-16 12:02:30

insert 是插入
remove 是删除
queryRank 查询排名
findKth 找第 K 小
findNxt 找后继
findPre 找前驱

visitVal 调试用(DEBUG):中序输出整棵树

#include<bits/stdc++.h>
#define int long long
#define x first
#define y second
using namespace std;
namespace FastIO{
    /**
     * Reads the next decimal integer from stdin.
     *
     * Every character before the first digit is skipped; if a '-' is seen
     * while skipping, the result is negated. Reading stops at the first
     * non-digit after the number (that character is consumed).
     *
     * @return the parsed value, or 0 if end of input is reached before any
     *         digit was found.
     */
    inline int read(){
        // getchar() reports end of input as EOF, which is only distinguishable
        // from real characters when kept in an int-sized variable; storing it
        // in a char made the skip loop below spin forever at end of input.
        int c=getchar();
        int sign=1;
        while(c!=EOF&&(c<'0'||c>'9')){
            if(c=='-')sign=-1;
            c=getchar();
        }
        int value=0;
        while(c>='0'&&c<='9'){
            value=value*10+(c-'0');
            c=getchar();
        }
        return value*sign;
    }
}
using FastIO::read;
// Capacity of the node pool (index 0 is reserved as the null node).
// NOTE(review): with `int` macro-expanded to long long, the six per-node words
// declared with this bound add up to roughly 960 MB of static storage; if the
// judge counts reserved rather than touched memory, shrink N to the maximum
// number of nodes the input can actually create.
const int N=2e7+10;
// Magnitude of the two sentinel keys (-inf / +inf) that bracket every real key.
// It must be strictly greater than any key that can be stored. The previous
// value 0x3f3f3f3f (1061109567) is smaller than 2^30, so keys in
// [0x3f3f3f3f, 2^30) would merge with or sort above the +inf sentinel.
// NOTE(review): P6136 keys are presumably < 2^30 -- confirm against the
// problem statement; 0x7fffffff leaves a full bit of head-room over that.
const int inf=0x7fffffff;
const int mod=1e9+7;
typedef unsigned long long ull;
typedef std::pair<int,int> pii;
typedef long long ll;
/*
 * Splay tree stored in parallel arrays, indexed by node id.
 * Node id 0 is the null node (absent child / parent of the root).
 */
int ch[N][2];   // ch[x][0] / ch[x][1]: child holding smaller / larger keys
int fa[N];      // parent of each node
int val[N];     // key stored in the node
int siz[N];     // number of stored values (counting duplicates) in the subtree
int num[N];     // how many copies of val[x] the node represents

int tot;        // id of the most recently allocated node
int rt;         // id of the current root
int n;          // first integer read in main: number of initial values
int m;          // second integer read in main: number of operations
int ans;        // XOR of all answers produced by query operations
int t;          // scratch variable used while reading the initial values

void visitVal(int x){
    if(ch[x][0])visitVal(ch[x][0]);
    if(val[x]!=inf&&val[x]!=-inf){
        cout<<val[x];
        if(num[x]>1) cout<<"("<<num[x]<<")";
        cout<<" ";
    }
    else if(val[x]==inf)           cout<<"+inf\n";
    else                           cout<<"-inf ";
    if(ch[x][1])visitVal(ch[x][1]);
}

void visitNum(int x){
    if(ch[x][0])visitNum(ch[x][0]);
    cout<<num[x]<<" ";
    if(ch[x][1])visitNum(ch[x][1]);
}

void pushUp(int x){
    siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+num[x];
}

void rotate(int x,int &f){
    int y=fa[x],z=fa[y],L=(ch[y][0]!=x),R=(L^1);
    if(y==f)f=x;else if(ch[z][0]==y)ch[z][0]=x;else ch[z][1]=x;
    fa[x]=z;fa[y]=x;fa[ch[x][R]]=y;
    ch[y][L]=ch[x][R];ch[x][R]=y;
    pushUp(y);pushUp(x);
}

/**
 * Splays x upward until it occupies the position referred to by f.
 *
 * @param x node to move up.
 * @param f reference to the link holding the top of the region (rt for a
 *          full splay, or a child slot of the root for a partial one);
 *          on return f == x.
 */
void Splay(int x,int &f){
    // Each iteration ends with one rotation of x; when the parent is not yet
    // the top node a preparatory rotation is performed first.
    for(;x!=f;rotate(x,f)){
        const int par=fa[x];
        if(par==f)continue;
        const int grand=fa[par];
        const bool xIsLeft=(ch[par][0]==x);
        const bool parIsLeft=(ch[grand][0]==par);
        if(xIsLeft!=parIsLeft)rotate(x,f);   // zig-zag: x goes up twice
        else rotate(par,f);                  // zig-zig: parent first, then x
    }
}
/**
 * Builds the initial tree consisting of the two sentinel nodes:
 * node 1 (-inf) is the root and node 2 (+inf) is its right child.
 */
void init(){
    // Node 1: lower sentinel, root of the tree, subtree covers both nodes.
    val[1]=-inf;
    num[1]=1;
    siz[1]=2;
    ch[1][1]=2;

    // Node 2: upper sentinel, right child of node 1.
    val[2]=+inf;
    num[2]=1;
    siz[2]=1;
    fa[2]=1;

    rt=1;
    tot=2;
}

/**
 * Adds one copy of key v to the subtree whose root link is x.
 *
 * If a node with key v already exists its copy count (and its own size) is
 * incremented; otherwise a fresh node is allocated and attached at the empty
 * link where the search ended. Sizes of the ancestors are NOT updated here;
 * the caller is expected to splay the returned node, which refreshes them.
 *
 * The descent is iterative: its length equals the height of the tree, which
 * for a splay tree can be linear in the number of nodes, so recursing here
 * could exhaust the call stack.
 *
 * @param x link (root variable or child slot) where the search starts; it is
 *          overwritten only if it is empty.
 * @param f parent to record if the node is created directly at x.
 * @param v key to insert.
 * @return id of the node that now holds v.
 */
int insert(int& x,int f,int v){
    int* slot=&x;     // link that will receive the new node if v is absent
    int parent=f;     // owner of *slot
    while(*slot){
        const int cur=*slot;
        if(val[cur]==v){
            num[cur]++;
            siz[cur]++;   // keep the node's own size in step with its count
            return cur;
        }
        parent=cur;
        slot=&ch[cur][val[cur]<v?1:0];
    }
    const int node=++tot;
    *slot=node;
    val[node]=v;
    siz[node]=num[node]=1;
    fa[node]=parent;
    return node;
}

/**
 * Inserts one copy of v into the tree and splays the node holding it to the
 * root.
 */
void insert(int v){
    const int node=insert(rt,0,v);
    Splay(node,rt);
}

/**
 * Returns the right-most node of the root's left subtree, i.e. the node
 * whose key immediately precedes the root's key. Yields 0 when the root has
 * no left child.
 */
int findPrev(){
    int cur=ch[rt][0];
    while(int right=ch[cur][1])cur=right;
    return cur;
}

/**
 * Returns the left-most node of the root's right subtree, i.e. the node
 * whose key immediately follows the root's key. Yields 0 when the root has
 * no right child.
 */
int findNext(){
    int cur=ch[rt][1];
    while(int left=ch[cur][0])cur=left;
    return cur;
}

/**
 * Looks up key v in the subtree rooted at x.
 *
 * The walk is iterative and stops at the null node, so a key that is not
 * stored yields 0 instead of descending from node 0 without end.
 *
 * @param x root of the subtree to search.
 * @param v key to look for.
 * @return id of the node holding v, or 0 if v is not present.
 */
int find(int x,int v){
    // Smaller keys live in ch[.][0], larger ones in ch[.][1].
    while(x&&val[x]!=v)x=ch[x][val[x]<v?1:0];
    return x;
}

void remove(int v){
    int x=find(rt,v);
    Splay(x,rt);
    if(num[x]>1){
        num[x]--;
        return;
    }
    int prev=findPrev();
    int next=findNext();
    Splay(prev,rt);
    Splay(next,ch[rt][1]);
    ch[next][0]=fa[x]=0;
    pushUp(next);pushUp(prev);
}

/**
 * Returns the rank of v: the size of the left subtree of v's node once that
 * node is the root. The -inf sentinel is part of that subtree, so the result
 * equals (number of stored values smaller than v) + 1.
 *
 * A temporary copy of v is inserted for the measurement and removed again
 * before returning.
 */
int queryRank(int v){
    insert(v);
    const int node=find(rt,v);
    Splay(node,rt);
    const int below=siz[ch[node][0]];
    remove(v);
    return below;
}

/**
 * Finds the node holding the k-th smallest stored value (1-based, counting
 * duplicates) within the subtree rooted at x.
 *
 * The descent is iterative and stops at the null node, so a rank outside
 * 1..siz[x] yields 0 instead of descending from node 0 without end.
 *
 * @param x root of the subtree to search.
 * @param k 1-based rank inside that subtree.
 * @return id of the node containing the k-th value, or 0 if k is out of range.
 */
int findKth(int x,int k){
    while(x){
        const int leftSize=siz[ch[x][0]];
        if(k<=leftSize){
            x=ch[x][0];
            continue;
        }
        if(k<=leftSize+num[x])return x;
        // Skip the left subtree and every copy stored in x itself.
        k-=leftSize+num[x];
        x=ch[x][1];
    }
    return 0;
}
/**
 * Returns the k-th smallest stored value. The rank is shifted by one
 * because the -inf sentinel occupies the first position in the tree.
 */
int findKth(int k){
    const int node=findKth(rt,k+1);
    return val[node];
}

/**
 * Returns the largest stored key strictly smaller than v.
 *
 * A temporary copy of v is inserted (which puts its node at the root), the
 * root's in-order predecessor is read, and the copy is removed again.
 */
int findPre(int v){
    insert(v);
    const int result=val[findPrev()];
    remove(v);
    return result;
}

/**
 * Returns the smallest stored key strictly greater than v.
 *
 * A temporary copy of v is inserted (which puts its node at the root), the
 * root's in-order successor is read, and the copy is removed again.
 */
int findNxt(int v){
    insert(v);
    const int result=val[findNext()];
    remove(v);
    return result;
}

/**
 * Driver: reads n initial values and m operations from stdin, applies them
 * to the tree and prints the XOR of all query answers.
 *
 * Each operation's argument is XOR-ed with the previous query answer before
 * use. Operations: 1 insert, 2 remove, 3 rank of a value, 4 value at a rank,
 * 5 predecessor, 6 successor.
 */
signed main(){
    init();
    n=read();
    m=read();
    for(int i=0;i<n;++i){
        t=read();
        insert(t);
    }
    int last=0;   // answer of the most recent query operation
    while(m--){
        const int opt=read();
        const int key=read()^last;
        switch(opt){
            case 1: insert(key); break;
            case 2: remove(key); break;
            case 3: last=queryRank(key); break;
            case 4: last=findKth(key); break;
            case 5: last=findPre(key); break;
            case 6: last=findNxt(key); break;
            default: break;
        }
        // Only query operations contribute to the running XOR.
        if(opt>2)ans^=last;
    }
    printf("%lld",ans);
    return 0;
}

by Liuhy2996 @ 2024-08-16 15:15:13

???为什么你会WA,我对着你代码过的P3369


by Peaky @ 2024-08-17 16:49:36

@Hangyu2011 不知道,我只过了 P3369


|