求助splay40ptsTLE

P6136 【模板】普通平衡树(数据加强版)

Gary88 @ 2020-12-12 17:00:48

#include<iostream>
#include<cstdio>
#include<ctime>
#include<algorithm>
using namespace std;
int mode,n,m,rt,cnt,last,totans;
double t1,t2;
struct node
{
    int ch[2],sz,tot,w,f;
}t[5000001];
void update(int x)
{
    t[x].tot=t[x].sz+t[t[x].ch[0]].tot+t[t[x].ch[1]].tot;
}
void rotate(int x)
{
    int y=t[x].f,z=t[y].f,k=(t[y].ch[1]==x);
    t[z].ch[(t[z].ch[1]==y)]=x;
    t[x].f=z;
    t[y].ch[k]=t[x].ch[k^1];
    t[t[x].ch[k^1]].f=y;
    t[x].ch[k^1]=y;
    t[y].f=x;
    update(y),update(x);
}
void splay(int x,int root)
{
    int ans=0;
    while(t[x].f!=root)
    {
        int y=t[x].f,z=t[y].f;
        if(z!=root)
        {
            if((t[z].ch[1]==y)^(t[y].ch[1]==x)) rotate(x);
            else rotate(y);
        }
        rotate(x);
    }
    if(!root)
    rt=x;
    else
    t[rt].ch[t[rt].w<t[x].w]=x;
}
void find(int x)
{
    int u=rt;
    while(t[u].ch[(t[u].w<x)]&&t[u].w!=x)
        u=t[u].ch[(t[u].w<x)];
    splay(u,0);
}
void insert(int x)
{
    int u=rt,fa=0;
    while(u&&x!=t[u].w)
    {
        fa=u;
        u=t[u].ch[(t[u].w<x)];
    }
    if(!u)
    {
        cnt++;
        if(!rt)
        rt=cnt;
        else
        t[fa].ch[(t[fa].w<x)]=cnt;
        t[cnt].w=x,t[cnt].sz=t[cnt].tot=1,t[cnt].f=fa;  
        splay(cnt,0);
    }
    else
    {
        t[u].sz++,t[u].tot++;
        splay(u,0);
    }
}
int low(int x)
{
    find(x);
    if(t[rt].w<x)
        return rt;
    int u=t[rt].ch[0],fa=rt;
    while(u)
    {
        fa=u;
        u=t[u].ch[1];
    }
    if(fa==rt)
    return -1;
    splay(fa,0);
    return fa;
}
int high(int x)
{
    find(x);
//  printf("%d\n",rt);
    if(t[rt].w>x)
        return rt;
    int u=t[rt].ch[1],fa=rt;
    while(u)
    {
        fa=u;
        u=t[u].ch[0];
    }
    if(fa==rt)
    return -1;
    splay(fa,0);
    return fa;
}
void del(int x)
{
    int y=low(x);
    int z=high(x);
    if(y!=-1&&z!=-1)
    {
        splay(y,0);
        splay(z,rt);
        int u=t[z].ch[0];
        if(t[u].sz==1)
        t[z].ch[0]=0;
        else
        {
            t[u].sz--,t[u].tot--;
            splay(u,0);
        }
    }
    else if(y==-1&&z!=-1)
    {
        splay(z,0);
        int u=t[z].ch[0];
        if(t[u].sz==1)
        t[z].ch[0]=0;
        else
        {
            t[u].sz--,t[u].tot--;
            splay(u,0);
        }
    }
    else if(y!=-1&&z==-1)
    {
        splay(y,0);
        int u=t[y].ch[1];
        if(t[u].sz==1)
        t[y].ch[1]=0;
        else
        {
            t[u].sz--,t[u].tot--;
            splay(u,0);
        }
    }
    else
    rt=0;
}
int find1(int x)
{
    find(x);
    if(t[rt].w<x)
        return t[t[rt].ch[0]].tot+2;
    else
        return t[t[rt].ch[0]].tot+1;
}
int find2(int x)
{
    if(x>t[rt].tot)
    return -1;
    int u=rt;
    while(1)
    {
        if(t[t[u].ch[0]].tot>=x)
        u=t[u].ch[0];
        else if(t[t[u].ch[0]].tot+t[u].sz<x)
        x-=t[t[u].ch[0]].tot+t[u].sz,u=t[u].ch[1];
        else
        break;
    }
    splay(u,0);
    return t[u].w;
}
int main()
{
    scanf("%d%d",&n,&m);
    while(n--)
    {
        int x;
        scanf("%d",&x);
        insert(x);
    }
    while(m--)
    {
        int x,xx;
        scanf("%d%d",&mode,&x);
        x^=last;
        switch(mode)
        {
            case 1:
                insert(x);
                break;
            case 2:
                del(x);
                break;
            case 3:
                last=find1(x);
                totans^=last;
                break;
            case 4:
                last=find2(x);
                totans^=last;
                break;
            case 5:
                xx=low(x);
                last=t[xx].w;
                totans^=last;
                break;
            case 6:
                xx=high(x);
                last=t[xx].w;
                totans^=last;
                break;
        }
    }
    printf("%d",totans);
    return 0;
}

|