Splay 30分求助

P6136 【模板】普通平衡树(数据加强版)

秋木弦 @ 2021-12-07 17:21:43

评测结果如标题所示：只得了 30 分。


#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pll;
typedef pair<int, int> pii;
#define all(x) x.begin(), x.end()
#define maxi(x) (max_element(x.begin(), x.end()) - x.begin())
#define mini(x) (min_element(x.begin(), x.end()) - x.begin())
// Every `int` from here on is really 64-bit, which also lets the sentinel key
// below exceed the 32-bit range.
#define int long long
// maxn: size of the node pool. remove() never frees nodes, so the pool must
//       cover the initial values, every later insert, and the two sentinels.
// inf:  sentinel key. P6136 keys can reach 2^30 - 1, which is larger than the
//       classic 0x3f3f3f3f (= 1061109567). The sentinel must be strictly larger
//       than every real key, so use a 64-bit value.
const int maxn = 2e6 + 10, inf = 0x3f3f3f3f3f3f3f3fLL;
int n, m;
// One splay-tree node, stored in the global pool tr[]; index 0 means "no node".
//   son[0] / son[1] : left / right child
//   p               : parent
//   v               : key
//   cnt             : number of copies of v held by this node
//   size            : total copies in this node's subtree
//   flag            : not read anywhere in the code shown
struct node
{
    int son[2], p, v;
    int size, cnt, flag;

    // Turn this slot into a single-copy node with key _v under parent _p.
    // The child links are not touched.
    void init(int _v, int _p)
    {
        flag = 0;
        cnt = 1;
        size = 1;
        p = _p;
        v = _v;
    }
} tr[maxn];
// root: index of the current root (0 = empty tree); idx: last allocated slot.
int root, idx;

void push_up(int u)
{
    tr[u].size = tr[tr[u].son[1]].size + tr[tr[u].son[0]].size + tr[u].cnt;
}

// Rotate x above its parent y, preserving in-order (BST) order, then refresh
// the sizes of y and x.
// Index 0 is the shared "null" node, so links are written into it only when
// it would be a real node. Otherwise the null node picks up bogus children or
// a bogus parent, and any later walk that lands on index 0 would follow them.
void rotate(int x)
{
    int y = tr[x].p, z = tr[y].p;
    int k = (tr[y].son[1] == x); // k = 0: x is y's left child; k = 1: right child
    int b = tr[x].son[k ^ 1];    // x's inner subtree; it moves under y

    if (z)
        tr[z].son[tr[z].son[1] == y] = x;
    tr[x].p = z;

    tr[y].son[k] = b;
    if (b)
        tr[b].p = y;

    tr[x].son[k ^ 1] = y;
    tr[y].p = x;

    push_up(y), push_up(x);
}

// Splay x upward until its parent is k (k == 0 means: make x the root).
// Zig-zig and zig-zag steps: if x and its parent hang on the same side of
// their own parents, rotate the parent first; otherwise rotate x twice.
void splay(int x, int k)
{
    for (int y = tr[x].p; y != k; y = tr[x].p)
    {
        const int z = tr[y].p;
        if (z != k)
        {
            const bool zigzag = (tr[y].son[1] == x) != (tr[z].son[1] == y);
            rotate(zigzag ? x : y);
        }
        rotate(x);
    }
    if (!k)
        root = x;
}

// Search from the root for key x and splay the last node reached to the root.
// That node is the one holding x if x is present. Otherwise it is the last
// node on the search path, whose key is adjacent to x.
// Does nothing when the tree is empty.
void Find(int x)
{
    if (root == 0)
        return;
    int u = root;
    for (;;)
    {
        if (tr[u].v == x)
            break;
        const int nxt = tr[u].son[x > tr[u].v];
        if (!nxt)
            break;
        u = nxt;
    }
    splay(u, 0);
}

int Next(int x, int f)
{
    Find(x);
    int u = root;
    if (tr[u].v > x && f)
        return u;
    if (tr[u].v < x && !f)
        return u;
    u = tr[u].son[f];
    while (tr[u].son[f ^ 1])
        u = tr[u].son[f ^ 1];
    return u;
}

// Node holding the predecessor of x (largest key strictly below x).
int get_pre(int x)
{
    return Next(x, false);
}

// Node holding the successor of x (smallest key strictly above x).
int get_nex(int x)
{
    return Next(x, true);
}

void insert(int v)
{
    int u = root, p = 0;
    while (u && tr[u].v != v)
    {
        p = u, u = tr[u].son[v > tr[u].v];
    }
    if (u)
        tr[u].cnt++;
    else
    {
        u = ++idx;
        if (p)
            tr[p].son[v > tr[p].v] = u;
        tr[u].init(v, p);
    }
    splay(u, 0);
}
// Debug helper: print the keys of u's subtree in in-order sequence, each
// followed by a space. u itself is always printed, even when u is 0.
void output(int u)
{
    const int lc = tr[u].son[0];
    const int rc = tr[u].son[1];
    if (lc != 0)
        output(lc);
    cout << tr[u].v << " ";
    if (rc != 0)
        output(rc);
}

void remove(int x)
{
    int last = get_pre(x);
    int next = get_nex(x);
    // cout << "......" << last << endl;
    splay(last, 0);
    splay(next, last);
    int del = tr[next].son[0];
    if (tr[del].cnt > 1)
    {
        tr[del].cnt--;
        splay(del, 0);
    }
    else
        tr[next].son[0] = 0;
}

int get_k(int x)
{
    int u = root;
    if (tr[u].size < x)
        return 0;
    while (true)
    {
        int y = tr[u].son[0];
        if (x > tr[y].size + tr[u].cnt)
        {
            x -= tr[y].size + tr[u].cnt;
            u = tr[u].son[1];
        }
        else if (tr[y].size >= x)
            u = y;
        else
            return tr[u].v;
    }
    return -1;
}

// Return the number of stored copies (sentinels included) strictly less than x.
// With the single low sentinel that solve() inserts, this equals x's 1-based
// rank among the real values, whether or not x itself is stored.
int get_rank(int x)
{
    Find(x);
    int smaller = tr[tr[root].son[0]].size;
    // Find leaves x's node, or a key adjacent to x, at the root. The root's
    // copies count as "less than x" only when its key is below x. If it is x
    // itself, or x's successor (x absent), they do not count.
    if (tr[root].v < x)
        smaller += tr[root].cnt;
    return smaller;
}

void out(int u)
{
    cout << "u=" << u << " " << tr[u].v << " " << tr[u].size << " " << tr[u].cnt << " " << tr[u].son[0] << " " << tr[u].son[1] << endl;
}

// Debug helper: dump every node of u's subtree in breadth-first order
// (left child before right), then print a separator line.
// Prints nothing at all when u is 0.
void outputf(int u)
{
    if (u == 0)
        return;
    vector<int> order{u};
    for (size_t head = 0; head < order.size(); ++head)
    {
        const int cur = order[head];
        out(cur);
        for (int side = 0; side < 2; ++side)
            if (tr[cur].son[side])
                order.push_back(tr[cur].son[side]);
    }
    cout << "---------" << endl;
}

// Read the n initial values and m forced-online operations, then print the
// XOR of all answers to queries 3-6. Every operand is XOR-decoded with the
// previous answer (`last`).
void solve()
{
    cin >> n >> m;
    // Sentinels strictly outside the key range give every real key a
    // predecessor and a successor, which remove() relies on. Real keys can be
    // 0, so the low sentinel must be negative. A sentinel of 0 would merge with
    // a real 0, and remove(0) would then look for a predecessor that does not
    // exist.
    insert(-inf), insert(inf);

    for (int i = 1; i <= n; i++)
    {
        int x;
        cin >> x;
        insert(x);
    }
    int last = 0, ans = 0;
    while (m--)
    {
        int op, x;
        cin >> op >> x;
        x ^= last;
        if (op == 1)
        {
            insert(x);
            continue;
        }
        if (op == 2)
        {
            remove(x);
            continue;
        }
        if (op == 3)
            last = get_rank(x);
        else if (op == 4)
            // +1 skips the low sentinel; size - 1 caps the rank below the high sentinel.
            last = get_k(min(x + 1, tr[root].size - 1));
        else if (op == 5)
            last = tr[get_pre(x)].v;
        else
            last = tr[get_nex(x)].v;
        ans ^= last;
    }
    cout << ans << endl;
}

// Entry point: untie C++ streams from C stdio and from each other for fast
// bulk input, then run the solver.
signed main()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    solve();
}

|