Splay 24pts求调

P6136 【模板】普通平衡树(数据加强版)

Nullity_Silence @ 2024-07-31 21:16:13

WA,TLE都有,TLE不是因为数组没开够

#include<iostream>
#include<algorithm>
#include<cstring>

using namespace std;

const int N=2e6+10;
const int inf=0x3f3f3f3f;

struct node
{
    int ch[2],fa,value,cnt,size;
};

int n,m,root,top;
node tree[N];

int get(int x)
{
    return x==tree[tree[x].fa].ch[1];
}

void pushup(int x)
{
    if(x==0)
        return;
    tree[x].size=tree[tree[x].ch[0]].size+tree[tree[x].ch[1]].size+tree[x].cnt;
}

void rotate(int x)
{
    int y=tree[x].fa;
    int z=tree[y].fa;
    int chk=get(x);
    tree[y].ch[chk]=tree[x].ch[chk^1];
    if(tree[x].ch[chk^1])
        tree[tree[x].ch[chk^1]].fa=y;
    tree[x].ch[chk^1]=y;
    tree[y].fa=x;
    if(z)
        tree[z].ch[y == tree[z].ch[1]]=x;
    tree[x].fa=z;
    pushup(y);
    pushup(x);
    return;
}

void splay(int x,int k)
{
    while(tree[x].fa!=k)
    {
        int y=tree[x].fa;
        int z=tree[y].fa;
        if(z!=k)
        {
            if(get(x)==get(y))
                rotate(y);
            else
                rotate(x);
        }
        rotate(x);
    }
    if(!k)
        root=x;
    return;
}

void insert(int x)
{
    int cur=root;
    int fa=0;
    while(cur)
    {
        if(tree[cur].value==x)
        {
            tree[cur].cnt++;
            pushup(cur);
            pushup(fa);
            splay(cur,0);
            return;
        }
        fa=cur;
        cur=tree[cur].ch[x>tree[cur].value];
    }
    cur=++top;
    tree[cur].value=x;
    tree[cur].fa=fa;
    tree[fa].ch[x>tree[fa].value]=cur;
    tree[cur].cnt=1;
    pushup(cur);
    pushup(fa);
    splay(cur,0);
    return;
}

int kth(int x)
{
    int cur=root;
    int res=0;
    while(cur)
    {
        if(x<tree[cur].value)
            cur=tree[cur].ch[0];
        else
        {
            res+=tree[tree[cur].ch[0]].size;
            if(x==tree[cur].value)
            {
                splay(cur,0);
                return res+1;
            }
            else
            {
                res+=tree[cur].cnt;
                cur=tree[cur].ch[1];
            }
        }
    }
    return res+1;
}

int rnk(int x)
{
    int cur=root;
    while(cur)
    {
        if(x<=tree[tree[cur].ch[0]].size)
            cur=tree[cur].ch[0];
        else if(x<=tree[tree[cur].ch[0]].size+tree[cur].cnt)
        {
            splay(cur,0);
            return cur;
        }
        else
        {
            x-=tree[tree[cur].ch[0]].size+tree[cur].cnt;
            cur=tree[cur].ch[1];
        }
    }
    return -1;
}

//pre:查询x的前驱(定义为小于x的最大的数)
int pre(int x)
{
    int cur=root;
    int res=-inf;
    while(cur)
    {
        if(x>tree[cur].value)
        {
            res=max(res,tree[cur].value);
            cur=tree[cur].ch[1];
        }
        else
            cur=tree[cur].ch[0];
    }
    return res;
}

int nxt(int x)
{
    int cur=root;
    int res=inf;
    while(cur)
    {
        if(x<tree[cur].value)
        {
            res=min(res,tree[cur].value);
            cur=tree[cur].ch[0];
        }
        else
            cur=tree[cur].ch[1];
    }
    return res;
}

void find(int x)
{
    int cur=root;
    while(cur)
    {
        if(x==tree[cur].value)
        {
            splay(cur,0);
            return;
        }
        if(x<tree[cur].value)
            cur=tree[cur].ch[0];
        else
            cur=tree[cur].ch[1];
    }
    return;
}

void del(int x)
{
    find(x);
    int l=tree[root].ch[0];
    int r=tree[root].ch[1];
    while(tree[l].ch[1])
        l=tree[l].ch[1];
    while(tree[r].ch[0])
        r=tree[r].ch[0];
    splay(l,0);
    splay(r,l);
    if(tree[tree[r].ch[0]].cnt>1)
    {
        tree[tree[r].ch[0]].cnt--;
        pushup(tree[r].ch[0]);
    }
    else
    {
        tree[r].ch[0]=0;
        pushup(tree[r].ch[0]);
    }
    pushup(r);
    pushup(l);
    return;
}

int main()
{
    std::ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    cin>>n>>m;
    insert(inf);
    insert(-inf);
    for(int i=1;i<=n;i++)
    {
        int a;
        cin>>a;
        insert(a);
    }
    int res=0;
    int last=0;
    for(int i=1;i<=m;i++)
    {
        int pos,x;
        cin>>pos>>x;
        switch (pos)
        {
            case 1:
                insert(x^last);
                break;
            case 2:
                del(x^last);
                break;
            case 3:
                last=(kth(x^last)-1);
                res^=last;
                break;
            case 4:
                last=tree[rnk((x^last)+1)].value;
                res^=last;
                break;
            case 5:
                last=pre(x^last);
                res^=last;
                break;
            case 6:
                last=nxt(x^last);
                res^=last;
                break;
            default:
                break;
        }
    }
    cout<<res<<endl;
    return 0;
}

|