80 分的 Splay 还有救吗？

P6136 【模板】普通平衡树(数据加强版)

Forward_Star @ 2020-02-27 13:51:59

有两个点 TLE（超时）了。

#include<cstdio>
#include<algorithm>
#define INF 2147483647
using namespace std;
    // One node of the splay tree. Equal keys are not stored as separate nodes:
    // inserting an existing key increments that node's multiplicity instead.
    struct splay
    {
        int size;   // sum of num over this whole subtree
        int num;    // multiplicity of v at this node; 0 = lazily deleted
        int v;      // key
        int fa;     // parent index, 0 for the root
        int son[2]; // son[0]: keys < v, son[1]: keys > v; 0 means no child
    };
    int n,m,cnt,root;
    // Index 0 is the null sentinel. Every insertion may allocate a fresh node,
    // so up to n + m nodes are needed (about 1.1e6 for P6136); the previous
    // bound of 1000001 could overflow.
    splay a[1100005];
// Which child of its parent node x is: 0 when x is the left child,
// 1 otherwise (including when x is not found under its parent's left slot).
int side(int x)
{
    const int parent = a[x].fa;
    return a[parent].son[0] != x ? 1 : 0;
}
// Recomputes the subtree total of `now` from its own multiplicity and the
// totals already stored in its two children (node 0 contributes 0).
void pushup(int now)
{
    const int *child = a[now].son;
    a[now].size = a[child[0]].size + a[child[1]].size + a[now].num;
}
// Rotates node x one level up over its parent y, keeping in-order (BST)
// order, then recomputes the sizes of y and x (y first, since it is now x's
// child).
// NOTE(review): when y is the root (a[y].fa == 0), the line that hooks x
// under y's parent writes into the sentinel node 0's son[] array; this looks
// harmless as long as nothing relies on node 0's children.
void rotate(int x)
{
    int y = a[x].fa;
    // d is the side opposite to x's position under y: x's inner child.
    int d = side(x) ^ 1;
    // x's inner child moves under y, into the slot x used to occupy.
    a[y].son[side(x)] = a[x].son[d];
    if (a[x].son[d])
        a[a[x].son[d]].fa = y;
    // x replaces y under y's parent (must run before a[y].fa is overwritten).
    a[a[y].fa].son[side(y)] = x;
    a[x].fa = a[y].fa;
    a[x].son[d] = y;
    a[y].fa = x;
    pushup(y);
    pushup(x);
}
// Moves node x to the root with zig / zig-zig / zig-zag rotations and
// records it as the new root.
void splay(int x)
{
    for (int y = a[x].fa; y != 0; y = a[x].fa)
    {
        const int z = a[y].fa;
        // Zig-zig: x and y lean the same way, so rotate the parent first.
        if (z != 0 && side(x) == side(y))
            rotate(y);
        rotate(x);
    }
    root = x;
}
// Inserts one occurrence of x. An existing node with key x just gets its
// multiplicity bumped; otherwise a new leaf is attached. The touched node is
// splayed to the root afterwards.
void insert(int x)
{
    // Empty tree: the new node becomes the root.
    if (root == 0)
    {
        cnt++;
        a[cnt].v = x;
        a[cnt].num = a[cnt].size = 1;
        root = cnt;
        return;
    }
    int parent = 0;
    int cur = root;
    // Every node on the search path gains one element in its subtree.
    for (; cur != 0 && a[cur].v != x; cur = a[cur].son[x > a[cur].v ? 1 : 0])
    {
        parent = cur;
        ++a[cur].size;
    }
    if (cur != 0)
    {
        ++a[cur].num;
        ++a[cur].size;
        splay(cur);
        return;
    }
    cnt++;
    a[cnt].v = x;
    a[cnt].num = a[cnt].size = 1;
    a[cnt].fa = parent;
    a[parent].son[x < a[parent].v ? 0 : 1] = cnt;
    splay(cnt);
}
void remove(int x)
{
    int now = root;
    while (now && a[now].v != x)
    {
        a[now].size --;
        if (x > a[now].v)
            now = a[now].son[1];
        else now = a[now].son[0];
    }
    if (now)
    {
        a[now].size --;
        a[now].num --;
    }
}
/*
int rank(int x)
{
    int now = root;
    while (now && a[now].v != x)
    {
        if (x > a[now].v)
            now = a[now].son[1];
        else now = a[now].son[0];
    }
    if (now)
    {
        splay(now);
        return a[a[now].son[0]].size + 1;
    }
    return 0;
}
*/
int rank(int x)
{
    int now = root;
    int sum = 0;
    while (now)
    {
        if (x > a[now].v)
        {
            sum += a[a[now].son[0]].size + a[now].num;
            now = a[now].son[1];
        }
        else now = a[now].son[0];
    }
    return sum + 1;
}
// Returns the x-th smallest live element (1-based, counting multiplicities)
// and splays its node to the root; returns 0 if x <= 0 or x exceeds the
// number of elements reachable by the descent.
int find(int x)
{
    int k = x;
    for (int cur = root; cur && k > 0; )
    {
        const int leftTotal = a[a[cur].son[0]].size;
        if (k <= leftTotal)
        {
            cur = a[cur].son[0];
        }
        else if (k <= leftTotal + a[cur].num)
        {
            // k falls within this node's copies (never true when num == 0).
            splay(cur);
            return a[cur].v;
        }
        else
        {
            k -= leftTotal + a[cur].num;
            cur = a[cur].son[1];
        }
    }
    return 0;
}
int lower(int now,int x)
{
    int ans = -INF;
    while (now)
    {
        if (x <= a[now].v)
            now = a[now].son[0];
        else 
        {
            if (a[now].num)
                ans = max(ans,a[now].v);
            else ans = max(ans,lower(a[now].son[0],x));
            now = a[now].son[1];
        }
    }
    return ans;
}
int upper(int now,int x)
{
    int ans = INF;
    while (now)
    {
        if (x >= a[now].v)
            now = a[now].son[1];
        else 
        {
            if (a[now].num)
                ans = min(ans,a[now].v);
            else ans = min(ans,upper(a[now].son[1],x));
            now = a[now].son[0];
        }
    }
    return ans;
}
// Fast reader for a non-negative decimal integer from stdin: skips every
// non-digit character (a '-' sign is skipped too, not applied), then
// accumulates consecutive digits.
int read()
{
    char ch;
    do
        ch = getchar();
    while (ch < '0' || ch > '9');
    int value = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        value = value * 10 + (ch - '0');
    return value;
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i = 1;i <= n;i ++)
    {
        int x;
        x = read();
        insert(x);
    }
    int ans = 0;
    int t = 0;
    for (int i = 1;i <= m;i ++)
    {
        int opt,x;
        opt = read();
        x = read();
        x ^= ans;
        if (opt == 1)
            insert(x);
        else if (opt == 2)
            remove(x);
        else
        {
            if (opt == 3)
            {
                // insert(x);
                ans = rank(x);
                // remove(x);
            }
            if (opt == 4)
                ans = find(x);
            if (opt == 5)
                ans = lower(root,x);
            if (opt == 6)
                ans = upper(root,x);
            t ^= ans;
        }
    }
    printf("%d",t);
    return 0;
}

by 狂气の月兔 @ 2020-02-27 14:04:03

快不了。。


by panyf @ 2020-02-27 15:14:52

@Forward_Star

有救的。首先数组要开到 110 万（110w），然后把最上面两个函数（side 和 pushup）去掉，直接展开写在后面的代码里，以减小常数。快读不加也能过。


by Fuyuki @ 2020-02-27 15:49:20

每 500 次操作就把中位数 splay 到根一次。


上一页 |