用字典树实现平衡树

masonxiong

2024-11-07 19:12:50

Algo. & Theory

引入

在 OI 里面我们经常会遇到一些需要用到平衡树的题。虽然其中大部分题都可以用 std::set 解决,但是总有那么一些东西是 std::set 解决不了的,比如查排名之类的。

那么这个时候,我们要么用 pbds 里面的 tree 要么自己写平衡树。但 pbds 里面的平衡树用起来非常麻烦,要记一堆什么乱七八糟的玩意,还不好调。那么我们最好的方案便是自己写一个平衡树。

但平衡树的码量是众所周知的大,动不动就上百行,而且一般都是写半小时调两小时,写完平衡树考试都结束了。再加上大部分平衡树维护平衡的方式都非常难记,各种左右乱转和分类讨论完全写不了。

为了解决这个问题,我们可以采取一种全新的方法:使用字典树(trie)来写平衡树

中心思想

以存储 unsigned long long64 位无符号整数)为例。

对于每一个元素,将其所对应的二进制当做一个长度为 64 的字符串插入到一个 trie 中进行存储(也就是所谓的 0-1 trie)。

下面给出一个简单的例子,实现了插入和查找出现次数的功能。

struct Node {
    int endCount;      // 存储以当前节点为结尾的数字串个数
    Node *children[2]; // children[0] 是当前节点的 '0' 字符指向的节点(左儿子),children[1] 是当前节点的 '1' 字符指向的节点(右儿子)
    Node() : endCount(0) {
        // 默认构造函数
        children[0] = children[1] = nullptr;
    }
} *root = new Node; // 根节点

void insert(long long value) {
    // 插入 value
    Node *current = root;
    bool direction;
    for (int i = 63; i >= 0; current = current->children[direction]) // 从高位到低位遍历
        if (current->children[direction = (value >> i--) & 1 /* 取当前位 */] == nullptr)
            current->children[direction] = new Node; // 不存在则新开节点
    current->endCount++; // 更新信息
}

int count(long long value) {
    // 统计 value 出现次数
    Node *current = root;
    bool direction;
    for (int i = 63; i >= 0; current = current->children[direction])
        if (current->children[direction = (value >> i--) & 1] == nullptr)
            return 0;
    return current->endCount;
}

平衡树六大操作的实现

平衡树的基本操作无非就是这 6 个:

我们一个一个拆解它们。

插入

就是常规 trie 的插入操作。

void insert(long long value) {
    // 插入 value
    Node *current = root;
    bool direction;
    for (int i = 63; i >= 0; current = current->children[direction]) // 从高位到低位遍历
        if (current->children[direction = (value >> i--) & 1 /* 取当前位 */] == nullptr)
            current->children[direction] = new Node; // 不存在则新开节点
    current->endCount++; // 更新信息
}

删除

正常我们用 trie 的时候似乎很少用到过删除操作。为了支持这一操作,我们需要维护每个节点被数字串经过的次数以及每个节点的父亲节点

struct Node {
    int endCount, existCount;   // existCount 统计当前节点被数字串经过的次数
    Node *children[2], *parent; // parent 指向父亲节点
    Node(Node *father = nullptr) : endCount(0), existCount(0), parent(father) {
        // 以 father 作为父亲节点构造此节点
        children[0] = children[1] = nullptr;
    }
};

如果我们想删除 x,那么我们先判断 x 是否存在。如果存在,我们找到 x 所对应数字串在 trie 上的那条路径中最下面的那个节点,将它的 endCount 减少 1,然后从下到上将路径上的每个节点的 existCount 减少 1,若一个节点的 existCount 减少至 0,则删除这个节点。

Node* find(long long value) {
    // 查找 value 所对应的数字串在 trie 上的路径中最下面的节点
    // 若 value 不存在则返回一个零指针 nullptr
    Node *current = root;
    bool direction;
    for (int i = 63; i >= 0; current = current->children[direction])
        if (current->children[direction = (value >> i--) & 1] == nullptr)
            return nullptr;
    return current->endCount ? current : nullptr;
}

void erase(long long value) {
    // 删除 value
    Node *current = find(value) /* 先查找 */, *parent;
    if (current != nullptr) {
        // value 存在
        for (current->endCount--; current != nullptr; current = parent) {
            // 自底向上遍历
            current->existCount--;
            parent = current->parent;
            if (current->existCount == 0 && parent != nullptr) {
                // 此处判 parent != nullptr 是为了防止把根节点删了
                // 当然你也可以判 current != root
                parent->children[current == parent->children[1]] = nullptr;
                // 清零父节点指针,这是一个很深金的写法
                delete current;
            }
        }
    }
}

由排名查值

这个就和常规 BST 很像了。对于一个非叶子节点,我们知道往左走得到的数肯定比往右走小,因为往左走意味着这一位是 0,剩余位无论怎么大也不会超过右边。

那么我们模仿 BST 的操作即可。假设我们要查排名为 rank 的值,那么我们从根节点开始走,若 rank 小于等于当前节点左儿子的 existCount 那就往左走,否则往右走并令将 rank 减去左儿子的 existCount(若当前节点没有左儿子,那它左儿子的 existCount 就为 0)。

一直这样走下去直到抵达叶子节点,这条路径就是我们查出来的值(因为我们插入的全都是 64 位数字串,所以一个数字串所对应的路径最下面的节点肯定是叶子节点)。

long long queryValue(int rank) {
    long long result = 0;
    bool direction;
    for (Node *current = root; rank != 0 && (current->children[0] != nullptr || current->children[1] != nullptr) /* 一直走到叶子 */; current = current->children[direction]) {
        result = (result << 1) | (direction = (rank > (current->children[0] == nullptr ? 0 : current->children[0]->existCount)));
        if (direction)
            rank -= current->children[0] == nullptr ? 0 : current->children[0]->existCount;
    }
    return result;
}

由值查排名

也是和 BST 的那个操作差不多。我们从根节点开始走,每当我们往右走的时候都将排名加上当前节点左儿子的 existCount,直到走不了了为止。

int queryRank(long long value) {
    int result = 1;
    Node *current = root;
    bool direction;
    for (int i = 63; i >= 0 && current != nullptr; current = current->children[direction])
        if ((direction = (value >> i--) & 1) && current->children[0] != nullptr)
            result += current->children[0]->existCount;
    return result;
}

查前驱

这里可以直接懒一下,因为我们维护的是整数,所以 queryPrevious(value) = queryValue(queryRank(value) - 1)

long long queryPrevious(long long value) {
    // 查 value 的前驱
    return queryValue(queryRank(value) - 1);
}

查后继

这里也可以懒一下。queryNext(value) = queryValue(queryRank(value + 1))

long long queryNext(long long value) {
    // 查 value 的后继
    return queryValue(queryRank(value + 1));
}

以上,平衡树六大基本操作全部实现。

至此,一锤定音。 尘埃,已然落定。

平衡树模板题通过记录。

需要注意的是,我们维护的是非负整数,对于有负数的情况我们需要将所有的数偏移一个常量 Delta 让所有的值都是正的。

优劣势分析

用 trie 实现的平衡树有优有劣。

优势

trie 实现平衡树的优势有很多:

劣势

当然它的劣势也有:

进阶使用

为了弥补 trie 实现的平衡树的不足,我们给出一些 trie 的进阶使用。

存储特殊类型数据

一个通用的方法是离线之后离散化,但是大部分需要使用平衡树的地方都要求在线。

然而对于一些特殊类型的数据,我们有特殊的方法可以在线处理。

(以字典序)存储字符串

这个非常的简单,直接把 0-1 trie 改写成普通 trie 即可。

在对一个给定的字符串进行操作时,需要先往它前面补极小元素(例如 \0)使得所有的字符串的长度都相等再进行操作(比如插入删除等)。返回的结果也记得先把前面补充的极小元删掉。

(以字典序)存储序列容器

和字符串类似,实际上是前者的推广。假设序列容器所存储的元素可以很方便地与一个较小的范围构成映射关系(例如字符类型可以很方便地映射到 [0,128) 这个区间中,只需要强行转成整形即可),那么就可以构造一个 trie 来存储它。

和字符串类似,在进行操作之前需要补 0。

存储浮点数

这里给出一种神奇的科技:直接将浮点类型所对应的二进制视作一个整型,然后把转出来的整形扔到 0-1 trie 里面存储。之所以能这样做,是因为把浮点数强转整形之后,它们仍然满足原有的大小关系

给出如下代码,大家可以自己试一下。

#include <iostream>

int main() {
    while (true) {
        double A;
        std::cout << "输入一个浮点数:";
        std::cin >> A;
        void *B = &A;
        long long *C = static_cast<long long*>(B);
        std::cout << "转化为的整数为 " << (*C) << '\n';
    }
    return 0;
}

trie 合并

就像线段树合并一样,我们可以写出 trie 树合并。

我们定义函数 Node* merge(Node* l, Node* r),作用是将以节点 r 为根的 trie 合并到以节点 l 为根的 trie 中。

首先我们需要先把 r 节点的信息合并到 l 节点中,即 l->endCount += r->endCount, l->existCount += r->existCount。然后我们递归调用这个函数分别合并其左右儿子即可。

Node* merge(Node *l, Node *r) {
    if (l == nullptr || r == nullptr)
        return l == nullptr ? r : l;
    l->endCount += r->endCount;
    l->existCount += r->existCount;
    l->children[0] = merge(l->children[0], r->children[0]);
    l->children[1] = merge(l->children[1], r->children[1]);
    return l;
}

用 trie 实现可持久化平衡树

这完全是可行的。注意到 trie 在插入的时候只会修改树上的一条路径,因此我们类比普通的可持久化数据结构,将这条路径进行复制,然后再更新,这样就可以获得一个新的版本。

然后我们开一个 vector 存储每一个版本的根节点就行了。但是需要注意的是,由于 trie 树的超大空间,在某些题目上用这种实现方法非常容易 MLE 导致爆 0。因此在你想用 trie 实现可持久化平衡树的时候,请先计算一下题目的空间限制是否允许你这样做。

这样是可以通过可持久化平衡树模板的。即使动态开点还用的是无优化 C++IO 也跑的很快。

模板

最后放两个模板吧。

用 0-1 trie 实现的存储无符号整数类型的平衡树

@file "TrieSet.hpp"

#ifndef TrieSet_hpp
#define TrieSet_hpp

#include <vector>
#include <functional>

namespace xmz {
    template <class Value>
    class TrieSet {
        static_assert(std::is_unsigned<Value>::value, "Value must be an unsigned integer");
        public:
            typedef Value valueType;
            typedef std::size_t sizeType;
            typedef std::ptrdiff_t differenceType;
        private:
            struct Node {
                sizeType existCount;
                Node *children[2], *parent;
                Node(Node *father = nullptr) : existCount(0), parent(father) {
                    children[0] = children[1] = nullptr;
                }
            };
        private:
            static const sizeType Length = sizeof(valueType) << 3;
            Node *root;
        public:
            TrieSet() : root(new Node) {}
            TrieSet(const valueType& value, sizeType count = 1) : root(new Node) {
                insert(value, count);
            }
            template <class inputIterator>
            TrieSet(inputIterator first, inputIterator last) : root(new Node) {
                for (inputIterator i = first; i != last; insert(*i++));
            }
            TrieSet(const TrieSet& copySource) = default;
            TrieSet(TrieSet&& moveSource) = default;
            TrieSet(std::initializer_list<valueType> initialList) : root(new Node) {
                assign(initialList);
            }
            ~TrieSet() = default;
            TrieSet& operator=(const TrieSet& copySource) = default;
            TrieSet& operator=(TrieSet&& moveSource) = default;
            TrieSet& operator=(std::initializer_list<valueType> initialList) {
                assign(initialList);
            }
            void assign(const valueType& value, sizeType count = 1) {
                clear();
                insert(value, count);
            }
            template <class inputIterator>
            void assign(inputIterator first, inputIterator last) {
                clear();
                for (inputIterator i = first; i != last; insert(*i++));
            }
            void assign(std::initializer_list<valueType> initialList) {
                clear();
                assign(initialList.begin(), initialList.end());
            }
            bool empty() const {
                return size() == 0;
            }
            sizeType size() const {
                return root->existCount;
            }
            void clear() {
                std::function<void(Node*)> DFS = [&](Node *current) -> void {
                    if (current != nullptr) {
                        DFS(current->children[0]);
                        DFS(current->children[1]);
                        delete current;
                    }
                };
                DFS(root);
                root = new Node;
            }
            void insert(const valueType& value, sizeType count = 1) {
                Node *current = root;
                bool direction;
                for (sizeType i = Length; i--; current = current->children[direction]) {
                    if (current->children[direction = (value >> i) & 1] == nullptr)
                        current->children[direction] = new Node(current);
                    current->existCount += count;
                }
                current->existCount += count;
            }
            void erase(const valueType& value, sizeType count = 1) {
                Node *current = root, *parent;
                bool direction;
                for (sizeType i = Length; i--; current = current->children[direction])
                    if (current->children[direction = (value >> i) & 1] == nullptr)
                        return;
                if (current != nullptr) {
                    for (; current != nullptr; current = parent) {
                        current->existCount -= count;
                        parent = current->parent;
                        if (current->existCount == 0 && parent != nullptr) {
                            parent->children[current == parent->children[1]] = nullptr;
                            delete current;
                        }
                    }
                }
            }
            void eraseAll(const valueType& value) {
                erase(value, count(value));
            }
            bool exist(const valueType& value) {
                return count(value) != 0;
            }
            sizeType count(const valueType& value) {
                Node *current = root;
                bool direction;
                for (sizeType i = Length; i--; current = current->children[direction])
                    if (current->children[direction = (value >> i) & 1] == nullptr)
                        return 0;
                return current->existCount;
            }
            valueType queryValue(sizeType rank) const {
                valueType result = 0;
                bool direction;
                for (Node *current = root; rank != 0 && (current->children[0] != nullptr || current->children[1] != nullptr); current = current->children[direction]) {
                    result = (result << 1) | (direction = (rank > (current->children[0] == nullptr ? 0 : current->children[0]->existCount)));
                    if (direction)
                        rank -= current->children[0] == nullptr ? 0 : current->children[0]->existCount;
                }
                return result;
            }
            sizeType queryRank(const valueType& value) const {
                sizeType result = 1;
                Node *current = root;
                bool direction;
                for (sizeType i = Length; i-- && current != nullptr; current = current->children[direction])
                    if ((direction = (value >> i) & 1) && current->children[0] != nullptr)
                        result += current->children[0]->existCount;
                return result;
            }
            valueType queryPrevious(const valueType& value) const {
                return queryValue(queryRank(value) - 1);
            }
            valueType queryNext(const valueType& value) const {
                return queryValue(queryRank(value + 1));
            }
            std::vector<valueType> traversal() const {
                std::vector<valueType> result;
                result.reserve(size());
                std::function<void(Node*, valueType)> DFS = [&](Node *current, valueType value) -> void {
                    if (current != nullptr) {
                        if (current->children[0] == nullptr && current->children[1] == nullptr) {
                            for (sizeType i = current->existCount; i--; result.push_back(value));
                            return;
                        }
                        DFS(current->children[0], value << 1);
                        DFS(current->children[1], (value << 1) | 1);
                    }
                };
                return DFS(root, 0), result;
            }
            void merge(TrieSet& sourceTrie) {
                std::function<Node*(Node*, Node*)> DFS = [&](Node *destination, Node *source) -> Node* {
                    if (destination == nullptr || source == nullptr)
                        return destination == nullptr ? source : destination;
                    destination->existCount += source->existCount;
                    destination->children[0] = DFS(destination->children[0], source->children[0]);
                    destination->children[1] = DFS(destination->children[1], source->children[1]);
                    return destination;
                };
                DFS(root, sourceTrie.root);
            }
            bool operator==(const TrieSet& rhs) const {
                return traversal() == rhs.traversal();
            }
            bool operator!=(const TrieSet& rhs) const {
                return traversal() != rhs.traversal();
            }
            bool operator<(const TrieSet& rhs) const {
                return traversal() < rhs.traversal();
            }
            bool operator<=(const TrieSet& rhs) const {
                return traversal() <= rhs.traversal();
            }
            bool operator>(const TrieSet& rhs) const {
                return traversal() > rhs.traversal();
            }
            bool operator>=(const TrieSet& rhs) const {
                return traversal() >= rhs.traversal();
            }
    };
}

#endif

用 0-1 trie 实现的存储无符号整数类型的可持久化平衡树

@file "PersistentTrieSet.hpp"

#ifndef PersistentTrieSet_hpp
#define PersistentTrieSet_hpp

#include <vector>
#include <functional>

namespace xmz {
    template <class Value>
    class PersistentTrieSet {
        static_assert(std::is_unsigned<Value>::value, "Value must be an unsigned integer");
        public:
            typedef Value valueType;
            typedef std::size_t sizeType;
            typedef std::ptrdiff_t differenceType;
        private:
            struct Node {
                sizeType existCount;
                Node *children[2], *parent;
                Node(Node *father = nullptr) : existCount(0), parent(father) { children[0] = children[1] = nullptr; }
                void copyData(const Node* sourcePointer) {
                    if (sourcePointer != nullptr) {
                        existCount = sourcePointer->existCount;
                        children[0] = sourcePointer->children[0];
                        children[1] = sourcePointer->children[1];
                    }
                }
            };
        private:
            static const sizeType Length = sizeof(valueType) << 3;
            std::vector<Node*> roots;
        public:
            PersistentTrieSet() : roots(1, new Node) {}
            PersistentTrieSet(const PersistentTrieSet& copySource) = default;
            PersistentTrieSet(PersistentTrieSet&& moveSource) = default;
            ~PersistentTrieSet() = default;
            PersistentTrieSet& operator=(const PersistentTrieSet& copySource) = default;
            PersistentTrieSet& operator=(PersistentTrieSet&& moveSource) = default;
            bool empty(sizeType version = 0) const {
                return size(version) == 0;
            }
            sizeType size(sizeType version = 0) const {
                return roots[version]->existCount;
            }
            bool exist(const valueType& value, sizeType version = 0) const {
                return count(value, version) != 0;
            }
            sizeType count(const valueType& value, sizeType version = 0) const {
                Node *current = roots[version];
                bool direction;
                for (sizeType i = Length; i--; current = current->children[direction])
                    if (current->children[direction = (value >> i) & 1] == nullptr)
                        return 0;
                return current->existCount;
            }
            valueType queryValue(sizeType rank, sizeType version = 0) const {
                valueType result = 0;
                bool direction;
                for (Node *current = roots[version]; rank != 0 && (current->children[0] != nullptr || current->children[1] != nullptr); current = current->children[direction]) {
                    result = (result << 1) | (direction = (rank > (current->children[0] == nullptr ? 0 : current->children[0]->existCount)));
                    if (direction)
                        rank -= current->children[0] == nullptr ? 0 : current->children[0]->existCount;
                }
                return result;
            }
            sizeType queryRank(const valueType& value, sizeType version = 0) const {
                sizeType result = 1;
                Node *current = roots[version];
                bool direction;
                for (sizeType i = Length; i-- && current != nullptr; current = current->children[direction])
                    if ((direction = (value >> i) & 1) && current->children[0] != nullptr)
                        result += current->children[0]->existCount;
                return result;
            }
            valueType queryPrevious(const valueType& value, sizeType version = 0) const {
                return queryValue(queryRank(value, version) - 1, version);
            }
            valueType queryNext(const valueType& value, sizeType version = 0) const {
                return queryValue(queryRank(value + 1, version), version);
            }
            std::vector<valueType> traversal(sizeType version = 0) const {
                std::vector<valueType> result;
                result.reserve(size());
                std::function<void(Node*, valueType)> DFS = [&](Node *current, valueType value) -> void {
                    if (current != nullptr) {
                        if (current->children[0] == nullptr && current->children[1] == nullptr) {
                            for (sizeType i = current->existCount; i--; result.push_back(value));
                            return;
                        }
                        DFS(current->children[0], value << 1);
                        DFS(current->children[1], (value << 1) | 1);
                    }
                };
                return DFS(roots[version], 0), result;
            }
            bool noVersion() const {
                return size() == 0;
            }
            sizeType countVersions() const {
                return roots.size() - 1;
            }
            void copyVersion(sizeType sourceVersion) {
                roots.push_back(roots[sourceVersion]);
            }
            void eraseLatestVersion() {
                roots.pop_back();
            }
            bool NVEmpty(sizeType version) {
                return copyVersion(version), empty(version);
            }
            sizeType NVSize(sizeType version) {
                return copyVersion(version), size(version);
            }
            bool NVExist(sizeType version, const valueType& value) {
                return copyVersion(version), exist(value, version);
            }
            sizeType NVCount(sizeType version, const valueType& value) {
                return copyVersion(version), count(value, version);
            }
            valueType NVQueryValue(sizeType version, sizeType rank) {
                return copyVersion(version), queryValue(rank, version);
            }
            sizeType NVQueryRank(sizeType version, const valueType& value) {
                return copyVersion(version), queryRank(value, version);
            }
            valueType NVQueryPrevious(sizeType version, const valueType& value) {
                return copyVersion(version), queryPrevious(value, version);
            }
            valueType NVQueryNext(sizeType version, const valueType& value) {
                return copyVersion(version), queryNext(value, version);
            }
            void NVInsert(sizeType version, const valueType& value, sizeType count = 1) {
                roots.push_back(new Node);
                Node *oldRoot = roots[version], *current = roots.back();
                bool direction;
                current->copyData(oldRoot);
                for (sizeType i = Length; i--; current->copyData(oldRoot)) {
                    current->existCount += count;
                    current->children[direction = (value >> i) & 1] = new Node(current);
                    current = current->children[direction];
                    oldRoot = oldRoot == nullptr ? nullptr : oldRoot->children[direction];
                }
                current->existCount += count;
            }
            void NVErase(sizeType version, const valueType& value, sizeType count = 1) {
                roots.push_back(roots[version]);
                if (exist(value, version)) {
                    roots.back() = new Node;
                    Node *oldRoot = roots[version], *current = roots.back();
                    bool direction;
                    current->copyData(oldRoot);
                    for (sizeType i = Length; i--; current->copyData(oldRoot)) {
                        if ((current->existCount -= count) == 0) {
                            if (current->parent != nullptr) {
                                current->parent->children[current == current->parent->children[1]] = nullptr;
                                delete current;
                            }
                            return;
                        }
                        current->children[direction = (value >> i) & 1] = new Node(current);
                        current = current->children[direction];
                        oldRoot = oldRoot == nullptr ? nullptr : oldRoot->children[direction];
                    }
                    if ((current->existCount -= count) == 0) {
                        current->parent->children[current == current->parent->children[1]] = nullptr;
                        delete current;
                    }
                }
            }
            void NVEraseAll(sizeType version, const valueType& value) {
                NVErase(version, value, count(value, version));
            }
    };
}

#endif

结语

传统艺能

以上,我们探讨了字典树这一数据结构的概念、结构特性以及它的应用场景。

从基础的插入和搜索操作,到实现平衡树以及可持久化等高级功能,字典树以其独特的优势在数据处理和搜索引擎中扮演着不可或缺的角色。

随着技术的不断进步,字典树的应用也在不断扩展,从简单的字符串匹配,维护数据到复杂的自然语言处理,它都展现出了强大的生命力和灵活性。

我希望这篇文章能够帮助读者更好地理解字典树的工作原理,并激发大家对这一数据结构更深层次探索的兴趣。

在未来的学习和实践中,字典树无疑将成为解决各种问题的有力工具。让我们一起期待字典树在新技术中的应用,以及它如何帮助我们构建更加智能和高效的系统。

感谢您的阅读,如果您对字典树有更深入的问题或想要探讨相关话题,我们欢迎您的反馈和讨论。