splay TLE#10 96pts求助

P6136 【模板】普通平衡树(数据加强版)

TernaryTree @ 2022-09-17 12:06:59

rt,本机跑非常慢。

#include <bits/stdc++.h>
#define reg register 

using namespace std;

/// Buffered reader that parses whitespace-separated decimal integers from stdin.
///
/// The global instance `fin` is used like an input stream: `fin >> a >> b;`.
/// Input is pulled from stdin in large chunks with fread() and parsed by hand.
struct ios {
    /// Returns the next input byte as a value in [0, 255], or EOF once stdin is
    /// exhausted.
    ///
    /// The byte is carried as an int so that end-of-input can never be confused
    /// with a real 0xFF byte, and so that the value is always a legal argument
    /// for isdigit() regardless of whether plain `char` is signed.
    inline int next_byte() {
        static const int inlen = 1 << 18 | 1;
        static char buf[inlen], *s, *t;
        if (s == t) {
            // Buffer drained: refill it; an empty refill means end of input.
            t = (s = buf) + fread(buf, 1, inlen, stdin);
            if (s == t) return EOF;
        }
        return static_cast<unsigned char>(*s++);
    }

    /// Returns the next input byte narrowed to `char`; end of input is reported
    /// as `char(-1)`. Kept for callers that want the raw character.
    inline char read() {
        return static_cast<char>(next_byte());
    }

    /// Reads one (optionally negative) decimal integer into `x`.
    ///
    /// Every non-digit byte before the number is skipped; if a '-' is among the
    /// skipped bytes the parsed value is negated. If end of input is reached
    /// before any digit, `x` is left unmodified.
    ///
    /// @param x  destination; any arithmetic type that supports `x * 10 + int`.
    /// @return   `*this`, so extractions can be chained.
    template<typename T> inline ios& operator>> (T &x) {
        int c = next_byte();
        bool negative = false;
        for (; !isdigit(c); c = next_byte()) {
            if (c == EOF) return *this;
            negative |= (c == '-');
        }
        for (x = 0; isdigit(c); c = next_byte()) x = x * 10 + (c - '0');
        if (negative) x = -x;
        return *this;
    }
} fin;

/// Minimal output helper built on C stdio; the global instance `fout` is used
/// like an output stream: `fout << 42 << " text" << endl;`.
///
/// All overloads write through the C `stdout` stream, so their output is
/// ordered consistently with each other.
struct exios {
    /// Names the function-pointer type of stream manipulators such as
    /// std::endl, so that `fout << endl` can be overloaded on it.
    template<typename CharT, typename Traits = std::char_traits<CharT>>
    struct typ {
        typedef std::basic_ostream<CharT, Traits>& (* end) (std::basic_ostream<CharT, Traits>&);
    };

    /// Prints `num` in decimal.
    ///
    /// The magnitude is computed in unsigned arithmetic so that the most
    /// negative int (whose negation does not fit in int) prints correctly.
    friend exios &operator<<(exios &out, int num) {
        unsigned int magnitude = static_cast<unsigned int>(num);
        if (num < 0) {
            putchar('-');
            magnitude = 0u - magnitude;
        }
        // Digits come out least-significant first; buffer them and emit in reverse.
        char digits[3 * sizeof(int) + 1];
        int len = 0;
        do {
            digits[len++] = static_cast<char>('0' + magnitude % 10);
            magnitude /= 10;
        } while (magnitude != 0);
        while (len > 0) putchar(digits[--len]);
        return out;
    }

    /// Prints a NUL-terminated C string verbatim.
    friend exios &operator<<(exios &out, const char * s) { fputs(s, stdout); return out; }

    /// Prints all bytes of `s` (taken by reference to avoid a copy).
    friend exios &operator<<(exios &out, const std::string &s) {
        fwrite(s.data(), 1, s.size(), stdout);
        return out;
    }

    /// Handles `fout << endl`: writes a newline. The manipulator itself is
    /// ignored and the stream is not flushed.
    friend exios &operator<<(exios &out, typ<char>::end) { putchar('\n'); return out; }
} fout;

/// Splay tree holding a multiset of ints.
///
/// Each distinct value lives in one node together with its multiplicity
/// (`count`); `size` is the number of stored elements (with multiplicity) in
/// the node's subtree. Every query splays the node it ends on, including
/// unsuccessful searches, using the standard zig / zig-zig / zig-zag steps.
/// This is what the amortized O(log n) analysis of splay trees relies on.
///
/// The tree owns its nodes; it is therefore non-copyable.
struct splayTree {
    struct node {
        node * parent;
        node * child[2];   // child[0] = smaller values, child[1] = larger values
        int value, count, size;

        node(int val): value(val) {
            parent = child[0] = child[1] = nullptr;
            count = size = 1;
        }
    };

    node * root;

    splayTree() {
        root = nullptr;
    }

    ~splayTree() {
        destroy(root);
    }

    // Copying would make two trees delete the same nodes.
    splayTree(const splayTree &) = delete;
    splayTree & operator=(const splayTree &) = delete;

    /// Deletes every node of the subtree rooted at `cur`.
    ///
    /// Runs iteratively (walking back up through parent pointers) so that a
    /// degenerate, list-shaped tree cannot overflow the call stack. Pointers
    /// held outside the subtree (including `root`) are not reset.
    void destroy(node * cur) {
        if (cur == nullptr) {
            return;
        }
        node * const stop = cur->parent;
        while (cur != stop) {
            if (cur->child[0]) {
                node * next = cur->child[0];
                cur->child[0] = nullptr;   // mark as handled before descending
                cur = next;
            } else if (cur->child[1]) {
                node * next = cur->child[1];
                cur->child[1] = nullptr;
                cur = next;
            } else {
                node * up = cur->parent;
                delete cur;
                cur = up;
            }
        }
    }

    /// Recomputes `cur->size` from its own count and its children's sizes.
    void update(node * cur) {
        if (cur == nullptr) {
            return;
        }
        cur->size = cur->count;
        if (cur->child[0]) {
            cur->size += cur->child[0]->size;
        }
        if (cur->child[1]) {
            cur->size += cur->child[1]->size;
        }
    }

    /// Returns which child of its parent `cur` is (0 or 1), or -1 when `cur`
    /// is null or has no parent.
    int get(node * cur) {
        if (cur == nullptr || cur->parent == nullptr) {
            return -1;
        }
        return cur->parent->child[1] == cur;
    }

    /// Links `cur` as child number `type` of `parent`. Either pointer may be
    /// null, in which case only the other side of the link is written.
    void connect(node * parent, node * cur, int type) {
        if (parent) {
            parent->child[type] = cur;
        }
        if (cur) {
            cur->parent = parent;
        }
    }

    /// Rotates `cur` one level up over its parent, keeping the search-tree
    /// order and the subtree sizes. No-op when `cur` is null or has no parent.
    void rotate(node * cur) {
        if (cur == nullptr || cur->parent == nullptr) {
            return;
        }
        node * parent = cur->parent;
        node * grandparent = parent->parent;
        int type = get(cur);
        int parent_type = get(parent);   // -1 when there is no grandparent

        connect(parent, cur->child[type ^ 1], type);
        connect(cur, parent, type ^ 1);
        if (parent == root) {
            root = cur;
        }
        // With a null grandparent this only clears cur->parent.
        connect(grandparent, cur, parent_type);

        // parent is now below cur, so it must be refreshed first.
        update(parent);
        update(cur);
    }

    /// Moves `cur` to the top of the tree it belongs to.
    ///
    /// Each step performs two rotations whenever a grandparent exists:
    /// parent-then-node when both are children on the same side (zig-zig),
    /// node-then-node otherwise (zig-zag). Doing only one rotation per step
    /// degenerates into plain rotate-to-root and loses the amortized bound.
    void splay(node * cur) {
        if (cur == nullptr) {
            return;
        }
        while (cur->parent) {
            node * parent = cur->parent;
            if (parent->parent) {
                rotate(get(cur) == get(parent) ? parent : cur);
            }
            rotate(cur);
        }
    }

    /// Returns the node holding `value` (splayed to the root), or nullptr if
    /// the value is absent. On a miss the last node visited is splayed.
    node * find(int value) {
        node * cur = root;
        node * last = nullptr;
        while (cur) {
            last = cur;
            if (cur->value < value) {
                cur = cur->child[1];
            } else if (cur->value > value) {
                cur = cur->child[0];
            } else {
                splay(cur);
                return cur;
            }
        }
        splay(last);
        return nullptr;
    }

    /// Returns 1 + the number of stored elements strictly smaller than
    /// `value`, or -1 if `value` is not stored. The node the search ends on is
    /// splayed in both cases.
    int findrank(int value) {
        node * cur = root;
        node * last = nullptr;
        int size = 0;
        while (cur) {
            last = cur;
            if (value < cur->value) {
                cur = cur->child[0];
            } else {
                if (cur->child[0]) {
                    size += cur->child[0]->size;
                }

                if (value == cur->value) {
                    size += 1;
                    splay(cur);
                    return size;
                }

                size += cur->count;
                cur = cur->child[1];
            }
        }
        splay(last);
        return -1;
    }

    /// Returns the node holding the `rank`-th smallest element (1-based,
    /// counting duplicates), splayed to the root, or nullptr when `rank`
    /// exceeds the number of stored elements.
    node * getrank(int rank) {
        node * cur = root;
        node * last = nullptr;
        while (cur) {
            last = cur;
            node * left = cur->child[0];
            if (left && rank <= left->size) {
                cur = left;
                continue;
            }
            if (left) {
                rank -= left->size;
            }
            rank -= cur->count;

            if (rank <= 0) {
                splay(cur);
                return cur;
            }
            cur = cur->child[1];
        }
        splay(last);
        return nullptr;
    }

    /// Returns the in-order neighbour of `cur`: the predecessor for
    /// `type == 0`, the successor for `type == 1`. The neighbour is splayed to
    /// the root. Returns nullptr when `cur` is null or has no such neighbour.
    node * pre_suf(int type, node * cur) {
        if (cur == nullptr) {
            return nullptr;
        }

        // With cur at the root, the neighbour is the extreme node of one subtree.
        splay(cur);
        cur = cur->child[type];
        while (cur && cur->child[type ^ 1]) {
            cur = cur->child[type ^ 1];
        }
        splay(cur);
        return cur;
    }

    /// Adds one occurrence of `value` and returns its node, splayed to the root.
    node * insert(int value) {
        if (root == nullptr) {
            root = new node(value);
            return root;
        }
        node * cur = root;
        node * parent = nullptr;
        int type = 0;
        while (cur) {
            if (cur->value < value) {
                parent = cur;
                cur = cur->child[1];
                type = 1;
            } else if (cur->value > value) {
                parent = cur;
                cur = cur->child[0];
                type = 0;
            } else {
                cur->count++;
                update(cur);   // keeps size right even when cur is already the root
                splay(cur);
                return cur;
            }
        }
        cur = new node(value);
        connect(parent, cur, type);
        splay(cur);
        return cur;
    }

    /// Walks from `cur` to the smallest (`type == 0`) or largest (`type == 1`)
    /// node of its subtree, splays that node and returns it. Returns nullptr
    /// when `cur` is null.
    node * min_max(int type, node * cur) {
        if (cur == nullptr) {
            return nullptr;
        }
        while (cur->child[type]) {
            cur = cur->child[type];
        }
        splay(cur);
        return cur;
    }

    /// Removes one occurrence of the value stored in `cur`; the node itself is
    /// deleted only when its last occurrence goes away. Passing nullptr (for
    /// example the result of a failed find()) is a no-op.
    void remove(node * cur) {
        if (cur == nullptr) {
            return;
        }
        splay(cur);
        if (cur->count >= 2) {
            cur->count--;
            update(cur);
            return;
        }
        node * left = cur->child[0];
        node * right = cur->child[1];
        if (left == nullptr) {
            if (right) {
                right->parent = nullptr;
            }
            root = right;
        } else {
            // Make the left subtree the whole tree, bring its maximum to the
            // top (it then has no right child) and hang the right subtree there.
            left->parent = nullptr;
            root = left;
            node * join = left;
            while (join->child[1]) {
                join = join->child[1];
            }
            splay(join);
            connect(join, right, 1);
            update(join);
        }
        delete cur;
    }
};

// The single tree instance that main() runs every operation against.
splayTree tree;

// Program state shared with main(); all start at zero as namespace-scope ints.
int n;    // number of initial values inserted before the queries
int m;    // number of queries
int ai;   // scratch slot for one initial value
int op;   // operation code of the current query
int num;  // argument of the current query
int ans;  // running XOR of the answers produced so far

int main() {
    int last = 0;
    fin >> n >> m;
    for (reg int i = 0; i < n; ++i) {
        fin >> ai;
        tree.insert(ai);
    }
    for (reg int i = 0; i < m; ++i) {
        fin >> op >> num;
        num ^= last;
        if (op == 1) {
            tree.insert(num);
        } else if (op == 2) {
            tree.remove(tree.find(num));
        } else if (op == 3) {
            tree.insert(num);
            last = tree.findrank(num);
            tree.remove(tree.find(num));
            ans ^= last;
        } else if (op == 4) {
            last = tree.getrank(num)->value;
            ans ^= last;
        } else if (op == 5) {
            tree.insert(num);
            last = tree.pre_suf(0, tree.find(num))->value;
            ans ^= last;
            tree.remove(tree.find(num));
        } else if (op == 6) {
            tree.insert(num);
            last = tree.pre_suf(1, tree.find(num))->value;
            ans ^= last;
            tree.remove(tree.find(num));
        }
    }
    fout << ans << endl;
    return 0;
}

by d0j1a_1701 @ 2022-09-17 12:50:35

@ternary_tree 试试把指针换掉


|