Splay tree: 96 pts, TLE on test #19 — requesting debugging help (求调)

P6136 【模板】普通平衡树(数据加强版)

modfisher @ 2023-12-11 21:13:06

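As in the title (rt): this Splay solution scores 96 points but gets TLE on test #19; please help find the bug.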

#include <bits/stdc++.h>

using namespace std;

const int maxn = 2000000 + 5;  // capacity of the node pool (index 0 is reserved)

// One node of the splay tree. Index 0 plays the role of the null node:
// a child/parent link equal to 0 means "no such node".
struct Splay {
    int val;    // key stored in this node
    int cnt;    // how many copies of val are stored here
    int ch[2];  // ch[0] = left child, ch[1] = right child
    int fa;     // parent index
    int siz;    // total multiplicity stored in this subtree
};
Splay sp[maxn];

int root = 0;  // index of the current root; 0 means the tree is empty
int tot = 0;   // highest node index handed out so far

// Recompute the subtree size of node x from its two children plus its own
// multiplicity. A missing child is index 0, whose siz is presumably always 0
// (sp is zero-initialised) — TODO confirm nothing ever calls up(0).
void up(int x){
    const int left = sp[x].ch[0];
    const int right = sp[x].ch[1];
    sp[x].siz = sp[x].cnt + sp[left].siz + sp[right].siz;
}
// Returns 1 when x is the left child of its parent, 0 otherwise.
// The result is used directly as a child-slot index by rotate().
int lchild(int x){
    const int parent = sp[x].fa;
    return sp[parent].ch[0] == x ? 1 : 0;
}
void clear(int x){
    sp[x].ch[0] = sp[x].ch[1] = sp[x].cnt = sp[x].fa = sp[x].siz = sp[x].val = 0;
}
void rotate(int x){
    int y = sp[x].fa, z = sp[y].fa, xl = lchild(x), yl = lchild(y);
    if(!y) return;
    sp[y].ch[!xl] = sp[x].ch[xl];
    if(sp[x].ch[xl]) sp[sp[x].ch[xl]].fa = y;
    sp[y].fa = x;
    sp[x].ch[xl] = y;
    sp[x].fa = z;
    sp[z].ch[!yl] = x;
    up(x);
    up(y);
}
void splay(int x){
    for(int i = sp[x].fa; i; i = sp[x].fa){
        if(sp[i].fa){
            if(lchild(x) ^ lchild(i)) rotate(x);
            else rotate(i);
        }
        rotate(x);
    }
    root = x;
}
// Insert one occurrence of x. If a node holding x is found on the search
// path its count is bumped; otherwise a fresh leaf is attached. In both
// cases that node is splayed to the root.
void insert(int x){
    if(root == 0){
        // Empty tree: the new node becomes the root.
        const int node = ++tot;
        root = node;
        sp[node].val = x;
        sp[node].cnt++;
        up(node);
        return;
    }
    int cur = root;
    int parent = 0;
    for(;;){
        if(sp[cur].val == x){
            ++sp[cur].cnt;
            up(cur);
            if(parent != 0) up(parent);
            splay(cur);
            return;
        }
        parent = cur;
        const int dir = x > sp[cur].val;  // 0 = go left, 1 = go right
        cur = sp[cur].ch[dir];
        if(cur == 0){
            // Fell off the tree: hang a new leaf in the empty slot.
            cur = ++tot;
            sp[cur].fa = parent;
            sp[cur].val = x;
            sp[cur].cnt++;
            sp[parent].ch[dir] = cur;
            up(cur);
            up(parent);
            splay(cur);
            return;
        }
    }
}
int ranking(int x){
    int now = root, res = 0;
    while(1){
        if(x < sp[now].val) now = sp[now].ch[0];
        else{
            res += sp[sp[now].ch[0]].siz;
            if(x == sp[now].val){
                splay(now);
                return res + 1;
            }
            res += sp[now].cnt;
            now = sp[now].ch[1];
        }
    }
}
// Return the x-th smallest stored value (1-based, counting multiplicities)
// and splay its node to the root. The caller must ensure
// 1 <= x <= total size; an out-of-range x is not detected here.
int kth(int x){
    int cur = root;
    for(;;){
        const int left = sp[cur].ch[0];
        if(left && x <= sp[left].siz){
            cur = left;          // answer lies in the left subtree
            continue;
        }
        x -= sp[left].siz + sp[cur].cnt;
        if(x <= 0) break;        // answer is this node's value
        cur = sp[cur].ch[1];
    }
    splay(cur);
    return sp[cur].val;
}
// Find the in-order predecessor of the current root: the rightmost node of
// its left subtree. Splay it to the root and return its index. Returns 0
// (leaving the tree untouched) when the root has no left subtree.
int pre(){
    int cur = sp[root].ch[0];
    if(!cur) return 0;
    for(int next = sp[cur].ch[1]; next; next = sp[cur].ch[1]) cur = next;
    splay(cur);
    return cur;
}
// Find the in-order successor of the current root: the leftmost node of
// its right subtree. Splay it to the root and return its index. Returns 0
// (leaving the tree untouched) when the root has no right subtree.
int nxt(){
    int cur = sp[root].ch[1];
    if(!cur) return 0;
    for(int next = sp[cur].ch[0]; next; next = sp[cur].ch[0]) cur = next;
    splay(cur);
    return cur;
}
// Remove one occurrence of x. Does nothing when the tree is empty, or when
// the node that ranking() leaves at the root does not hold x.
void del(int x){
    if(!root) return;
    ranking(x);                        // splays x's node to the root if present
    if(sp[root].val != x) return;
    if(sp[root].cnt > 1){
        sp[root].cnt --;
        up(root);
        return;
    }
    int oro = root;
    if(!sp[oro].ch[0] || !sp[oro].ch[1]){
        // Zero or one child: promote that child (or empty the tree).
        root = sp[oro].ch[0] ? sp[oro].ch[0] : sp[oro].ch[1];
        // BUG FIX: the promoted child must forget its old parent. Otherwise
        // the next splay climbs into the cleared node and re-links it as a
        // ghost (val 0, cnt 0). A later del(0) can then remove the ghost
        // instead of the real 0, corrupting every subsequent online query.
        if(root) sp[root].fa = 0;
        clear(oro);
        return;
    }
    // Two children: splay the predecessor up (oro becomes its right child
    // with no left subtree), then hang oro's right subtree on it.
    root = pre();
    sp[root].ch[1] = sp[oro].ch[1];
    if(sp[oro].ch[1]) sp[sp[oro].ch[1]].fa = root;
    up(root);
    clear(oro);
}
// Read the next non-negative decimal integer from stdin, skipping any
// non-digit characters before it. Returns 0 if EOF is reached before a
// digit; previously the skip loop spun forever at EOF.
int read(){
    int c = getchar();                // int, not char, so EOF stays distinct
    while(c != EOF && (c < '0' || c > '9')) c = getchar();
    int x = 0;
    while(c >= '0' && c <= '9'){
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x;
}

int main(){
    const int n = read();
    int m = read();
    // Build the initial multiset.
    for(int i = 0; i < n; ++i) insert(read());
    int lans = 0;   // previous answer; XOR-decodes the next query argument
    int rans = 0;   // XOR of all answers; printed at the end
    for(; m != 0; --m){
        const int op = read();
        const int x = read() ^ lans;
        switch(op){
        case 1:
            insert(x);
            break;
        case 2:
            del(x);
            break;
        case 4:
            lans = kth(x);
            rans ^= lans;
            break;
        default:
            // Ops 3, 5 and everything else: insert x (which splays it to
            // the root), answer relative to it, then remove it again.
            insert(x);
            if(op == 3) lans = ranking(x);
            else if(op == 5) lans = sp[pre()].val;
            else lans = sp[nxt()].val;
            rans ^= lans;
            del(x);
            break;
        }
    }
    printf("%d\n", rans);
    return 0;
}

by PNNNN @ 2023-12-11 22:55:19

%%%


|