splay AC#1,3,4,其余TLE 求助

P6136 【模板】普通平衡树(数据加强版)

j1ANGFeng @ 2022-07-13 16:59:58

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<string>
#define ll long long
#define N 10000001
#define int long long
using namespace std;
inline long long rd(){char a=getchar();long long f=1,x=0;while(a<'0'||a>'9'){if(a=='-')f=-1;a=getchar();}while(a>='0'&&a<='9'){x=(x<<3)+(x<<1)+(long(a^48));a=getchar();}return f*x;}void qwqqwq(long long x){if(x!=0){qwqqwq(x/10);putchar(x%10^48);}return;}inline void wt(long long x){if(x==0){putchar('0');return;}if(x<0){x=-x;putchar('-');}qwqqwq(x);return;}
int cnt=0;
struct splay{
    int s[2],fa,si,w,cnt;
}t[N];
#define root t[0].s[1]
int up(int i){
    t[i].si=t[t[i].s[0]].si+t[t[i].s[1]].si+t[i].cnt;
    return t[i].si;
}
bool id(int i){
    if(t[t[i].fa].s[0]==i)
      return 0;
    return 1;
}
void dad(int x,int f,int i){
    t[f].s[i]=x;
    t[x].fa=f;
    return;
}
void ro(int i){
    int y=t[i].fa,r=t[y].fa;
    int y1=id(i),r1=id(y);
    dad(t[i].s[y1^1],y,y1);
    dad(y,i,y1^1);
    dad(i,r,r1);
    up(y);
    up(i);
    return;
}
void splay(int x,int to){
    to=t[to].fa;
    int y=t[x].fa;
    while(t[x].fa!=to){
        int y=t[x].fa;
        if(t[y].fa==to)
          ro(x);
        else if(id(y)==id(x)){
            ro(y);
            ro(x);
        }else{
            ro(x);
            ro(x);
        }
    }
    return;
}
int newnode(int w,int f){
    t[++cnt].fa=f;
    t[cnt].si=t[cnt].cnt=1;
    t[cnt].w=w;
    return cnt;
}
void ins(int x){
    int now=root;
    if(root==0){
        root=newnode(x,0);
        return;
    }
    while(1){
        ++t[now].si;
        if(t[now].w==x){
            ++t[now].cnt;
            splay(now,root);
            return;
        }else{
            int nxt=1;
            if(x<t[now].w)
              nxt=0;
            if(t[now].s[nxt]==0){
                splay(t[now].s[nxt]=newnode(x,now),root);
                return;
            }
            now=t[now].s[nxt];
        }
    }
    return;
}
int find(int x){
    int now=root;
    while(1){
        if(!now)
          return 0;
        if(t[now].w==x){
            splay(now,root);
            return now;
        }
        int nxt=1;
        if(t[now].w>x)
          nxt=0;
        now=t[now].s[nxt];
    }
    return 0;
}
void del(int x){
    int pos=find(x);
    if(pos==0)
      return;
    if(t[pos].cnt>1){
        --t[pos].cnt;
        --t[pos].si;
        return;
    }
    if(t[pos].s[0]==0&&t[pos].s[1]==0){
        root=0;
        return;
    }
    if(t[pos].s[0]==0){
        root=t[pos].s[1];
        t[root].fa=0;
        return;
    }
    int l=t[pos].s[0];
    while(t[l].s[1])
      l=t[l].s[1];
    splay(l,t[pos].s[0]);
    dad(t[pos].s[1],l,1);
    dad(l,0,1);
    up(l);
    return;
}
int rankk(int x){
    return t[t[find(x)].s[0]].si+1;
}
int val(int x){
    int now=root;
    while(true){
        int num=t[now].si-t[t[now].s[1]].si;
        if(x>t[t[now].s[0]].si&&x<=num){
            splay(now,root);
            return t[now].w;
        }
        if(x<num)
          now=t[now].s[0];
        else now=t[now].s[1],x-=num;
    }
    return 0;
}
int lower(int x){
    int now=root,ans=-2147483647;
    while(now){
        if(t[now].w<x)
          ans=max(ans,t[now].w);
        if(t[now].w>=x)
          now=t[now].s[0];
        else now=t[now].s[1];
    }
    return ans;
}
int upper(int x){
    int now=root,ans=2147483647;
    while(now){
        if(t[now].w>x){
            ans=min(ans,t[now].w);
        }
        if(t[now].w<=x)
          now=t[now].s[1];
        else now=t[now].s[0];
    }
    return ans;
}
signed main(){
    int n=rd(),m=rd()+1,ans=0,sum=0;
    for(int i=1;i<=n;++i){
        int k=rd();
        ins(k);
    }
    while(--m){
        int lx=rd(),x=rd()^sum;
        if(lx==1)
          ins(x);
        if(lx==2)
          del(x);
        if(lx==3)
          sum=rankk(x);
        if(lx==4)
          sum=val(x);
        if(lx==5)
          sum=lower(x);
        if(lx==6)
          sum=upper(x);
        if(lx>2)
          ans^=sum;
    }
    wt(ans);
    return 0;
}

by 小熙熙 @ 2022-07-17 14:33:47

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<string>
#include<cmath>
#include<cstring>
#include<queue>
#include<map>
#include<vector>
#define bug cout<<"bug"<<endl
#define ll long long
#define inf 0x3f3f3f3f
#define mod 1000000007
using namespace std;
inline int read(){  int x=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
    return x*f;
}
const int maxn=1e6+100000;
int na;
int ch[maxn][2];
int val[maxn],dat[maxn];
int size[maxn],cnt[maxn];
int tot,root;
int New(int v){
    ++tot;
    val[tot]=v;
    dat[tot]=rand();
    size[tot]=1;
    cnt[tot]=1;
    return tot;
}
void pushup(int id){
    size[id]=size[ch[id][0]]+size[ch[id][1]]+cnt[id];
}
void build(){
    root=New(-inf);
    ch[root][1]=New(inf);
    pushup(root);
}
void rotate(int &id,int d){//id 为某一子树的根 
    int temp=ch[id][d^1];
    ch[id][d^1]=ch[temp][d];
    ch[temp][d]=id;
    id=temp;
    pushup(ch[id][d]);
    pushup(id); 
}
void insert(int &id,int v){
    if(!id){
        id=New(v);
        return ;
    }
    if(v==val[id]) cnt[id]++;
    else{
        int d=v<val[id]?0:1;
        insert(ch[id][d],v);
        if(dat[id]<dat[ch[id][d]]){
            rotate(id,d^1);
        }
    }
    pushup(id); 
}
void remove(int &id,int v){
    if(id==0) return;
    if(v==val[id]){
        if(cnt[id]>1){
            cnt[id]--;
            pushup(id);
            return ;
        }
        if(ch[id][0]||ch[id][1]){
            if(!ch[id][1]||dat[ch[id][0]]>dat[ch[id][1]]){
                rotate(id,1);
                remove(ch[id][1],v);
            }
            else{
                rotate(id,0);
                remove(ch[id][0],v);
            }
            pushup(id);
        }
        else id=0;
        return ;
    }
    v<val[id]?remove(ch[id][0],v):remove(ch[id][1],v);
    pushup(id);
}
int getpre(int v){
    int id=root,pre;
    while(id!=0){
        if(val[id]<v){
            pre=val[id];
            id=ch[id][1];
        }
        else id=ch[id][0];
    }
    return pre;
}
int getnext(int v){
    int id=root,next;
    while(id!=0){
        if(val[id]>v){
            next=val[id];
            id=ch[id][0];
        }
        else id=ch[id][1];
    }
    return next;
}
int getrank(int id,int v){
    if(id==0) return 1;
    if(v==val[id]) return size[ch[id][0]]+1;
    else if(v<val[id]) return getrank(ch[id][0],v);
    else return size[ch[id][0]]+cnt[id]+getrank(ch[id][1],v);
}
int getval(int id,int rankk){
    if(id==0) return getval(root,rankk-1);
    if(rankk<=size[ch[id][0]]) return getval(ch[id][0],rankk);
    else if(rankk<=size[ch[id][0]]+cnt[id]) return val[id];
    else return getval(ch[id][1],rankk-size[ch[id][0]]-cnt[id]);
} 
int n,m;
int main(){
    cin.tie(0);
    cout.tie(0);
    build();
    cin>>n>>m;
    for(int i=1;i<=n;i++){
        int x;
        cin>>x;
        insert(root,x);
    }
    int ans=0,last=0;
    while(m--){
        int op,x;
        cin>>op>>x;
        x^=last;
        if(op==1) insert(root,x);
        else if(op==2) remove(root,x);
        else if(op==3){
            last=getrank(root,x)-1;
            ans^=last;
        }
        else if(op==4){
            last=getval(root,x+1);
            ans^=last;
        }
        else if(op==5){
            last=getpre(x);
            ans^=last;
        }
        else if(op==6){
            last=getnext(x);
            ans^=last;
        }
    }
    cout<<ans<<"\n";
    return 0;
}

我也是。。。


|