从 LCT 到 SATT

zzzYheng

2025-01-06 21:59:27

Solution

参考资料:《浅谈一类实现简易的动态树型结构信息维护方法》肖岱恩。

本文不带复杂度证明,大家有兴趣就去看 X 神的论文吧。

众所周知,LCT 存在难以维护子树信息的缺点,因此我们考虑改进 LCT 以使其能维护子树信息。

由于 LCT 其实维护的是辅助树的结构,因此我们在辅助树上考虑原树的子树,容易发现 cut(fa, x)makeRoot(x)x 的子树就是以 x 为根的那棵辅助树上的所有节点了,因此我们只需在支持 LCT 的基本操作的同时维护辅助树的子树信息合并就可以了。

这个问题看起来是不难的,因为 LCT 本身就已经维护了每个点在辅助树上向下沿实边能到达的连通块的信息合并,我们只需要再增加一个沿虚边的信息合并就可以了!不过一个很严重的问题是:虚边可能是不止 \Theta(1) 条的,这不就寄了吗?

SATT 给出的一个解决方案是,我们再对每个节点的虚儿子建立一个 Splay Leafy Tree 的结构来维护虚儿子的信息合并,并将这个 SLT 的根作为实节点的中儿子以建立实节点和它的虚儿子们之间的联系。

比如下面这棵树:

其的 SATT 可能就会长这样:

比如图中的 7 号节点就是为了合并 1 的 4、5 两个虚儿子的信息而建立的虚节点,同时为了结构的统一性,即使 2 号节点只有 6 一个虚儿子,我们还是会建立一个虚节点 8 来作为 2 的中儿子。

有了这个结构之后,我们就能很方便地维护虚儿子的信息合并了。同时对子树修改也是简单的了,因为我们可以通过这个结构在辅助树上自顶向下地下传修改标记,那么任何双半群信息的子树修改、查询都可以使用 SATT 维护了!

下面是维护细节,我们只考虑 access(x) 这个函数,实现了这个函数后其他函数的实现是简单的:

如何分析其时间复杂度,感受一下,access(x) 的操作大致可以看为在整棵辅助树上 "splay" 了 x,那么 maccess(x) 的复杂度为 \Theta((n+m)\log{n}),更为严谨的分析可以参考 X 神的论文。

最后注意一些细节:

最后附上代码:

#include <bits/stdc++.h>

using namespace std;

const int kMaxN = 1e5 + 10;
const int kInf = 2e9;

int wtop, wlen, wstk[40];
char rdc[1<<14], wtc[1<<23], *rS, *rT;
#define getchar() (rS==rT?rT=(rS=rdc)+fread(rdc,1,1<<14,stdin),(rS==rT?EOF:*rS++):*rS++)
#define putchar(x) wtc[wlen++]=(x)
#define flush() fwrite(wtc,1,wlen,stdout),wlen=0

template<typename T>
T read() {
    char c = getchar(); T s = 0;
    for (; !isdigit(c); c = getchar());
    for (; isdigit(c); c = getchar()) s = s * 10 + (c ^ 48);
    return s;
}

template<typename T>
// 文末记得 flush 一下
void write(T x) {
    if(wlen>=8000000)flush();
  if (x < 0) {
    putchar('-');
    write(-x);
    return;
  }
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}

int n, q;

struct Tag {
    int k;
    int b; // x <- x * k + b

    Tag(int _k = 1, int _b = 0) {
        k = _k, b = _b;
    }

    void operator += (const Tag &O) {
        if (!O.k) *this = O;
        else (*this).b += O.b;
    }
};

struct Data {
    int mx, mn;
    int sum, siz;

    Data(int _mx = -kInf, int _mn = kInf, int _sum = 0, int _siz = 0) {
        mx = _mx, mn = _mn, sum = _sum, siz = _siz;
    }

    Data operator + (const Data &O) const {
        return Data(max(mx, O.mx), min(mn, O.mn), sum + O.sum, siz + O.siz);
    }

    void operator += (const Tag &O) {
        if (mn == kInf) return;
        if (!O.k) mx = mn = O.b, sum = O.b * siz;
        else mx += O.b, mn += O.b, sum += siz * O.b;
    }
};

struct Node {
    int fa, ch[3];
    Data val, sum_chain, sum_subtree;
    Tag tag_chain, tag_subtree;
    bool rev_tag;

    Node() {
        fa = ch[0] = ch[1] = ch[2] = 0;
        val = sum_chain = sum_subtree = Data();
        tag_chain = tag_subtree = Tag();
        rev_tag = 0;
    }
} tree[kMaxN << 1];
vector<int> node_stk;
int root;

int getNewNode() {
    int id = node_stk.back(); node_stk.pop_back();
    tree[id] = Node();
    return id;
}

int ls(int x) { return tree[x].ch[0]; }
int rs(int x) { return tree[x].ch[1]; }
int ms(int x) { return tree[x].ch[2]; }
int fa(int x) { return tree[x].fa; }
int get(int x) { 
    int y = fa(x);
    if (ls(y) == x) return 0;
    if (rs(y) == x) return 1;
    return 2; 
}
bool isSplayRoot(int x) { return !fa(x) || (fa(x) > n) != (x > n); }
bool isAuxRoot(int x) { return !fa(x); }

void pushUp(int x) {
    if (x > n) {
        tree[x].sum_subtree = tree[ls(x)].sum_chain + tree[ls(x)].sum_subtree + tree[rs(x)].sum_chain + tree[rs(x)].sum_subtree;
    }
    else {
        tree[x].sum_chain = tree[x].val + tree[ls(x)].sum_chain + tree[rs(x)].sum_chain;
        tree[x].sum_subtree = tree[ms(x)].sum_subtree + tree[ls(x)].sum_subtree + tree[rs(x)].sum_subtree;
    }
}

void putTag(int x, Tag tag_chain, Tag tag_subtree) {
    if (x > n) {
        tree[x].sum_subtree += tag_subtree;
        tree[x].tag_subtree += tag_subtree;
    }
    else {
        tree[x].val += tag_chain;
        tree[x].sum_chain += tag_chain;
        tree[x].tag_chain += tag_chain;
        tree[x].sum_subtree += tag_subtree;
        tree[x].tag_subtree += tag_subtree;
    }
}

void pushDown(int x) {
    if (x > n) {
        if (ls(x)) putTag(ls(x), tree[x].tag_subtree, tree[x].tag_subtree);
        if (rs(x)) putTag(rs(x), tree[x].tag_subtree, tree[x].tag_subtree);
        tree[x].tag_subtree = Tag();
    }
    else {
        if (tree[x].rev_tag) {
            if (ls(x)) {
                swap(tree[ls(x)].ch[0], tree[ls(x)].ch[1]);
                tree[ls(x)].rev_tag ^= 1;
            }
            if (rs(x)) {
                swap(tree[rs(x)].ch[0], tree[rs(x)].ch[1]);
                tree[rs(x)].rev_tag ^= 1;
            }
            tree[x].rev_tag = 0;
        }
        if (ls(x)) putTag(ls(x), tree[x].tag_chain, tree[x].tag_subtree);
        if (rs(x)) putTag(rs(x), tree[x].tag_chain, tree[x].tag_subtree);
        if (ms(x)) putTag(ms(x), Tag(), tree[x].tag_subtree);
        tree[x].tag_chain = tree[x].tag_subtree = Tag();
    }
}

void pushDownFromAuxRoot(int x) {
    if (!isAuxRoot(x)) pushDownFromAuxRoot(fa(x));
    pushDown(x);
}

void pushDownFromSplayRoot(int x) {
    if (!isSplayRoot(x)) pushDownFromSplayRoot(fa(x));
    pushDown(x);
}

void rotate(int x) {
    int y = fa(x), z = fa(y), t = get(x);
    tree[z].ch[get(y)] = x;
    tree[y].ch[t] = tree[x].ch[!t], tree[tree[y].ch[t]].fa = y;
    tree[x].ch[!t] = y, tree[y].fa = x;
    tree[x].fa = z;
    pushUp(y);
}

void splay(int x) {
    while (!isSplayRoot(x)) {
        int y = fa(x);
        if (!isSplayRoot(y)) rotate(get(y) == get(x) ? y : x);
        rotate(x);
    }
    pushUp(x);
}

void insertForVirSon(int x, int y) {
    int z = ms(x), t = getNewNode();
    pushDown(x);
    tree[x].ch[2] = t, tree[t].fa = x;
    tree[t].ch[0] = z, tree[t].ch[1] = y;
    tree[z].fa = t, tree[y].fa = t;
    pushUp(t), pushUp(x);
}

void access(int x) {
    int tmp = x;
    pushDownFromAuxRoot(x);
    splay(x);
    if (rs(x)) insertForVirSon(x, rs(x)), tree[x].ch[1] = 0;
    while (!isAuxRoot(x)) {
        int y = fa(x); 
        int z = y;
        while (z > n) z = fa(z);
        splay(z);
        if (rs(z)) {
            int t = rs(z);
            tree[y].ch[get(x)] = t, tree[t].fa = y;
            tree[z].ch[1] = x, tree[x].fa = z;
            splay(y);
        }
        else {
            int bro = tree[y].ch[!get(x)];
            node_stk.emplace_back(y);
            int f = fa(y);
            tree[f].ch[get(y)] = bro, tree[bro].fa = f;
            tree[z].ch[1] = x, tree[x].fa = z;
            if (f > n) splay(f);
        }
        x = z;
    }
    splay(tmp);
}

void makeRoot(int x) {
    access(x);
    swap(tree[x].ch[0], tree[x].ch[1]);
    tree[x].rev_tag ^= 1;
}

int getRoot(int x) {
    access(x);
    int cur = x;
    pushDown(x);
    while (ls(cur)) cur = ls(cur), pushDown(cur);
    splay(cur);
    return cur;
}

int getOgFa(int x) {
    if (x == root) return 0;
    makeRoot(root);
    access(x);
    int cur = ls(x);
    pushDown(x), pushDown(cur);
    while (rs(cur)) cur = rs(cur), pushDown(cur);
    splay(cur);
    return cur;
}

void splitPath(int x, int y) {
    // splay root is x
    makeRoot(y);
    access(x);
}

void link(int x, int y) {
    makeRoot(x), makeRoot(y);
    insertForVirSon(x, y);
}

void cut(int x, int y) {
    makeRoot(x), access(y);
    pushDown(y);
    tree[y].ch[0] = 0, tree[x].fa = 0;
    pushUp(y);
}

int main() {
    n = read<int>(), q = read<int>();
    vector<pair<int, int> > edge;
    for (int i = 1; i < n; ++i) {
        int u = read<int>(), v = read<int>();
        edge.emplace_back(u, v);
    }
    for (int i = 1; i <= n; ++i) {
        int val = read<int>();
        tree[i].val = tree[i].sum_chain = Data(val, val, val, 1);
    }
    for (int i = 1; i <= n; ++i) node_stk.emplace_back(i + n);
    root = read<int>();
    for (auto it : edge) link(it.first, it.second);

    while (q--) {
        int opt = read<int>();
        if (opt == 0 || opt == 5) {
            int x = read<int>(), y = read<int>();
            int fa = getOgFa(x);
            if (fa) access(fa);
            else makeRoot(root);
            Tag tag = Tag((!opt) ? 0 : 1, y);
            putTag(x, tag, tag);
            access(x);
        }
        else if (opt == 1) {
            root = read<int>();
        }
        else if (opt == 2 || opt == 6) {
            int x = read<int>(), y = read<int>(), z = read<int>();
            splitPath(x, y);
            Tag tag = Tag((opt == 2) ? 0 : 1, z);
            putTag(x, tag, Tag());
        }
        else if (opt == 3 || opt == 4 || opt == 11) {
            int x = read<int>();
            int fa = getOgFa(x);
            if (fa) access(fa);
            else makeRoot(root);
            Data data = tree[x].sum_chain + tree[x].sum_subtree;
            write((opt == 3) ? data.mn : (opt == 4) ? data.mx : data.sum), putchar('\n');
        }
        else if (opt == 7 || opt == 8 || opt == 10) {
            int x = read<int>(), y = read<int>();
            splitPath(x, y);
            write((opt == 7) ? tree[x].sum_chain.mn : (opt == 8) ? tree[x].sum_chain.mx : tree[x].sum_chain.sum), putchar('\n');
        }
        else {
            int x = read<int>(), y = read<int>();
            int fa = getOgFa(x);
            if (fa) {
                cut(fa, x);
                if (getRoot(x) == getRoot(y)) link(fa, x);
                else link(x, y);
            }
        }
    }
    flush();
    return 0;
}