WBLT 72pts 求助

P6136 【模板】普通平衡树(数据加强版)

xuyiyang @ 2024-08-24 13:44:26

WA 13,15,20,23 MLE 19 码风良好的 WBLT 与 oi-wiki 的类似。求调 注意:代码中的 N 开了 2.3\times 10^6,包括垃圾回收

int n, m, v[N];

// WBLT
const double alpha = 0.25;
int q[N], hh = 0, tt = 0;
int w[N], ls[N], rs[N], idx;
int sz[N], id; 

void pushup(int x) { if (!ls[x] && !rs[x]) return ; sz[x] = sz[ls[x]] + sz[rs[x]]; w[x] = w[rs[x]]; }
int get(int v) {  
    if (hh != tt) { id = q[hh ++ ]; if (hh == N) hh = 0; }
    else id = ++ idx;
    w[id] = v, ls[id] = rs[id] = 0, sz[id] = 1;
    return id;
}
int build(int l, int r) {
    if (l == r) return get(v[l]);
    int mid = l + r >> 1, x = get(0);
    ls[x] = build(l, mid), rs[x] = build(mid + 1, r);
    return pushup(x), x;
}
int z;
int merge(int x, int y) {
    z = get(w[y]); ls[z] = x, rs[z] = y;
    return pushup(z), z;
}
void rotate(int x, int y) { // 0 left 1 right
    if (!y) {
        rs[x] = merge(rs[ls[x]], rs[x]);
        ls[x] = ls[ls[x]];
    } else {
        ls[x] = merge(ls[x], ls[rs[x]]);
        rs[x] = rs[rs[x]];
    }
}
void maintain(int x) {
    if (!ls[x] && !rs[x]) return ;
    if (sz[ls[x]] > sz[rs[x]]) {
        if (sz[rs[x]] >= sz[x] * alpha) return ;
        if (sz[rs[ls[x]]] >= sz[ls[x]] * (1 - 2 * alpha) / (1 - alpha)) rotate(ls[x], 1);
        rotate(x, 0);
    } else {
        if (sz[ls[x]] >= sz[x] * alpha) return ;
        if (sz[ls[rs[x]]] >= sz[rs[x]] * (1 - 2 * alpha) / (1 - alpha)) rotate(rs[x], 0);
        rotate(x, 1); 
    }
}
void cpy(int x, int y) { ls[x] = ls[y], rs[x] = rs[y], sz[x] = sz[y], w[x] = w[y]; }
void ins(int v, int x) {
    if (!ls[x] && !rs[x]) {
        ls[x] = get(min(v, w[x])), rs[x] = get(max(v, w[x]));
        pushup(x); return maintain(x), void();
    }
    if (v <= w[ls[x]]) ins(v, ls[x]); else ins(v, rs[x]);
    pushup(x); maintain(x);
}
void del(int v, int x, int fa) {
    if (!ls[x] && !rs[x]) {
        if (ls[fa] == x) cpy(fa, rs[fa]); else cpy(fa, ls[fa]);
        q[tt ++ ] = x; ls[x] = rs[x] = w[x] = sz[x] = 0;
        if (tt == N) tt = 0;
        return ;
    }
    if (v <= w[ls[x]]) del(v, ls[x], x); else del(v, rs[x], x); 
    pushup(x); maintain(x);
}
int rk(int v, int x) {
    if (!ls[x] && !rs[x]) return 1;
    if (v <= w[ls[x]]) return rk(v, ls[x]);
    return sz[ls[x]] + rk(v, rs[x]);
}
int kth(int v, int x) {
    if (!ls[x] && !rs[x]) return w[x];
    if (v <= sz[ls[x]]) return kth(v, ls[x]);
    return kth(v - sz[ls[x]], rs[x]);
}

bool Med;
int main() {
    int res = 0, lst = 0;
    rd(n), rd(m);
    for (int i = 1; i <= n; i ++ ) rd(v[i]); sort(v + 1, v + 1 + n); build(1, n);
    while (m -- ) {
        int op, x; rd(op), rd(x); x ^= lst;
        if (op == 1) ins(x, 1);
        else if (op == 2) del(x, 1, 0);
        else if (op == 3) lst = rk(x, 1), res ^= lst;
        else if (op == 4) lst = kth(x, 1), res ^= lst;
        else if (op == 5) lst = kth(rk(x, 1) - 1, 1), res ^= lst;
        else lst = kth(rk(x + 1, 1), 1), res ^= lst;
    } printf("%d\n", res);
    return 0;
}

by lao_wang @ 2024-08-24 13:49:25

@OIer_tan


by wanchenhao @ 2024-08-24 13:55:13

@xuyiyang

AC代码求关

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 2e6+10;
struct Node{
    int son[2],fa;
    int siz,val;
}tr[maxn];
int tot,rt;
int newnode(int x,int fa){
    tr[++tot]={{0,0},fa,1,x};
    return tot;
}
void push_up(int p){
    tr[p].siz=tr[tr[p].son[0]].siz+tr[tr[p].son[1]].siz+1;
}
bool get(int p,int fa){
    return p==tr[fa].son[1];
}
void clear(int p){
    tr[p]={{0,0},0,0,0};
}
void connect(int x,int fa,int k){
    tr[x].fa=fa;
    tr[fa].son[k]=x;
}
void rotate(int p){
    int fa=tr[p].fa,ffa=tr[fa].fa,op=get(p,fa);
    connect(tr[p].son[op^1],fa,op);
    connect(p,ffa,get(fa,ffa));
    connect(fa,p,op^1);
    push_up(fa);
    push_up(p);
}
void splay(int p,int goal=0){
    if(goal==0) rt=p;
    while(tr[p].fa!=goal){
        int fa=tr[p].fa,ffa=tr[fa].fa;
        if(ffa!=goal)
            rotate(get(p,fa)==get(fa,ffa)?fa:p);
        rotate(p);
    }
}
void Insert(int &p,int fa,int w){
    if(!p) splay(p=newnode(w,fa));
    else Insert(tr[p].son[w>=tr[p].val],p,w);
}
void delnode(int p){
    splay(p);
    if(tr[p].son[1]){
        int x=tr[p].son[1];
        while(tr[x].son[0]) x=tr[x].son[0];
        splay(x,p);
        connect(tr[p].son[0],x,0);
        rt=x;
        tr[rt].fa=0;
        push_up(rt);
    }else{
        rt=tr[p].son[0];
        tr[rt].fa=0;
    }
}
void Del(int &p,int x){
    if(tr[p].val==x) delnode(p);
    else Del(tr[p].son[x>=tr[p].val],x);
}
int Rank(int x){
    int p=rt,res=1,fa=0;
    while(p){
        if(x<=tr[p].val){
            fa=p;
            p=tr[p].son[0];
        }else{
            res+=tr[tr[p].son[0]].siz+1;
            p=tr[p].son[1];
        }
    }
    if(fa) splay(fa);
    return res;
}
int kth(int k){
    int p=rt;
    while(1){
        if(tr[p].son[0]&&tr[tr[p].son[0]].siz>=k){
            p=tr[p].son[0];
        }else{
            k-=tr[tr[p].son[0]].siz+1;
            if(k<=0){
                splay(p);
                return tr[p].val;
            }
            p=tr[p].son[1];
        }
    }
}
int pre(int x){
    return kth(Rank(x)-1);
}
int suf(int x){
    return kth(Rank(x+1));
}
int n,m,res,lastans;
int main(){
    scanf("%d%d",&n,&m);
    for(int i=1,x;i<=n;++i){
        scanf("%d",&x);
        Insert(rt,0,x);
    }
    while(m--){
        int op,x;
        scanf("%d%d",&op,&x);
        x^=lastans;
        if(op==1) Insert(rt,0,x);
        if(op==2) Del(rt,x);
        if(op==3) res^=lastans=Rank(x);
        if(op==4) res^=lastans=kth(x);
        if(op==5) res^=lastans=pre(x);
        if(op==6) res^=lastans=suf(x);
    }
    printf("%d",res);
    return 0;
}

by OldDriverTree @ 2024-08-24 14:08:04

@wanchenhao 你看不见 lz 求助的是 WBLT 吗?lz 让调代码,你发个 AC 代码是什么意思?


by xuyiyang @ 2024-08-24 15:16:08

@wanchenhao 大哥这是 Splay


by lg10 @ 2024-10-05 12:48:37

一开始插入的序列之后可能会被删完,所以还是需要在一开始插入inf

我也是WBLT72pts(


|