题解:P5637 ckw的树

Kev1nL1kesCod1ng

2024-11-20 19:11:11

Solution

模拟赛时想到一个的绝妙做法。

下文中的 v 是指与 u 相邻的节点,su 的儿子,ru_u 为点 u 的度数,su_u=\sum ru_v

f_u 为从点 u 开始走的期望步数,则可以列出方程:

f_u=\frac{\sum_{dis(u,k)\le 2} f_k}{su_u+1}+1

若为点 u 为终点,则 f_u=0

但是方程 u 与较远的 k 有关系,不好算。能不能只和相邻的点有关系?答案是可以的,设 g_u

g_u=\sum f_v

那么算 f_u 只需要容斥一下:

f_u=\frac{\sum f_v + \sum g_v - ru_u f_u + f_u}{su_u + 1}+1

f_u 都移到一边:

f_u=\frac{\sum f_v + \sum g_v + su_u + 1}{su_u + ru_u}

这样所有的式子就只和相邻项有关系了。

考虑如何解这些方程,直接做就是高斯消元了,但是我们需要更优的做法。

考虑小学二年级是如何解方程的,没错就是代入消元,不难发现这里可以先从叶子开始代入消元,譬如叶子 u,有 g_u=f_{fa},考虑用 f_{fa}g_{fa} 表示 f_ug_u,形如:

[f_u,g_u]=[f_{fa},g_{fa},1] \begin{bmatrix} \dots & \dots \\ \dots & \dots \\ \dots & \dots \\ \end{bmatrix} f_u=[f_{fa},g_{fa},1]F_u \\ g_u=[f_{fa},g_{fa},1]G_u

其中 F_uG_u 都表示大小一乘三的矩阵。

这样点 u 的值就能用父亲来表示,而点 u 方程中含有有关儿子 s 的项,都可以用 s 的父亲 u 表示,这样方程就只和父亲有关了。

因为叶子节点没有儿子,所以 dfs,从叶子开始,一步一步向上代入消元即可。

这样的话,点 u 的方程消掉儿子,就变成:

g_u=f_{fa}+[f_u,g_u,1]\sum F_s f_u=\frac{[f_u,g_u,1](\sum F_s+\sum G_s)+f_{fa} + g_{fa} + su_u + 1}{su_u + ru_u}

这里用任何方法解出方程即可,比如暴力推式子或者搞大小二乘五矩阵的高斯消元。

然后对于终点的 g_u 特判一下即可。

然后求出根节点 F_{rt}G_{rt} 后,不难想到因为根节点没有父亲,所以没有 f_{fa}g_{fa},所以 f_{rt} 的值即为 F_{rt} 的常数项 {F_{rt}}_{3,1},同理 g_{rt}={G_{rt}}_{3,1}

求出根节点的两个值后,再做一次 dfs 把所有点的 f_ug_u 全部反推出来即可。

忽略逆元复杂度,时间复杂度为 O(n)

const int N = 1e5 + 5;
const int P = 998244353;
int add(int x, int y) { return (x + y < P ? x + y : x + y - P); }
void Add(int & x, int y) { x = (x + y < P ? x + y : x + y - P); }
int sub(int x, int y) { return (x < y ? x - y + P : x - y); }
void Sub(int & x, int y) { x = (x < y ? x - y + P : x - y); }
int mul(int x, int y) { return (1ll * x * y) % P; }
void Mul(int & x, int y) { x = (1ll * x * y) % P; }
int fp(int x, int y) {
    int res = 1;
    for(; y; y >>= 1) {
        if(y & 1) Mul(res, x);
        Mul(x, x);
    }
    return res;
}
int n, m, a[N];
int fi[N], ne[N << 1], to[N << 1], ecnt;
int ru[N], su[N];
struct Node { // 1x3 矩阵
    int f, g, c;
    friend Node operator + (Node A, Node B) {
        Node res;
        res.f = add(A.f, B.f);
        res.g = add(A.g, B.g);
        res.c = add(A.c, B.c);
        return res;
    }
    friend Node operator - (Node A, Node B) {
        Node res;
        res.f = sub(A.f, B.f);
        res.g = sub(A.g, B.g);
        res.c = sub(A.c, B.c);
        return res;
    }
    friend Node operator * (Node A, int B) {
        Node res;
        res.f = mul(A.f, B);
        res.g = mul(A.g, B);
        res.c = mul(A.c, B);
        return res;
    }
    Node & operator += (const Node &A) {
        return *this = *this + A;
    }
    Node & operator -= (const Node &A) {
        return *this = *this - A;
    }
    Node & operator *= (const int &A) {
        return *this = *this * A;
    }
} f[N], g[N];
int A[2][5]; // 用于高斯消元解小方程的 2x5 矩阵
int ans[N][2];
void add_edge(int u, int v) {
    ne[++ ecnt] = fi[u];
    to[ecnt] = v;
    fi[u] = ecnt;
}
void guass() { // 高斯消元解小方程
    if(! A[0][0]) swap(A[0], A[1]);
    int val = fp(A[0][0], P - 2);
    REP(i, 5) Mul(A[0][i], val);
    val = A[1][0];
    REP(i, 5) Sub(A[1][i], mul(A[0][i], val));
    val = fp(A[1][1], P - 2);
    REP(i, 5) Mul(A[1][i], val);
    val = A[0][1];
    REP(i, 5) Sub(A[0][i], mul(A[1][i], val));
    REP(i, 5) A[0][i] = sub(0, A[0][i]);
    REP(i, 5) A[1][i] = sub(0, A[1][i]);
}
void dfs1(int u, int fa) {
    Node F = {0, 0, 0}, G = {0, 0, 0};
    for(int i = fi[u]; i; i = ne[i]) {
        int v = to[i];
        if(v == fa) continue;
        dfs1(v, u);
        F += f[v], F += g[v];
        G += f[v];
    }
    if(! a[u]) {
        A[0][0] = G.f;
        A[0][1] = sub(G.g, 1);
        A[0][2] = 1;
        A[0][3] = 0;
        A[0][4] = G.c;
        A[1][0] = F.f;
        A[1][1] = F.g;
        A[1][2] = 1;
        A[1][3] = 1;
        A[1][4] = add(F.c, su[u] + 1);
        int val = fp(su[u] + ru[u], P - 2);
        REP(i, 5) Mul(A[1][i], val);
        Sub(A[1][0], 1);
        guass();
        f[u] = {A[0][2], A[0][3], A[0][4]};
        g[u] = {A[1][2], A[1][3], A[1][4]};
    }
    else { // 是终点要特判一下
        int val = sub(1, G.g);
        g[u].f = 1, g[u].c = G.c;
        g[u] *= fp(val, P - 2);
    }
} 
void dfs2(int u, int fa) {
    for(int i = fi[u]; i; i = ne[i]) {
        int v = to[i];
        if(v == fa) continue;
        Add(ans[v][0], mul(ans[u][0], f[v].f));
        Add(ans[v][0], mul(ans[u][1], f[v].g));
        Add(ans[v][0], f[v].c);
        Add(ans[v][1], mul(ans[u][0], g[v].f));
        Add(ans[v][1], mul(ans[u][1], g[v].g));
        Add(ans[v][1], g[v].c);
        dfs2(v, u);
    }
}
void solve() {
    cin >> n >> m;
    REP(_, n - 1) {
        int u, v;
        cin >> u >> v;
        add_edge(u, v);
        add_edge(v, u);
        ru[u] ++, ru[v] ++;
    }
    FOR(u, 1, n) for(int i = fi[u]; i; i = ne[i]) {
        int v = to[i];
        Add(su[u], ru[v]);
    }
    REP(_, m) {
        int x; cin >> x;
        a[x] = 1;
    }
    dfs1(1, 0);
    ans[1][0] = f[1].c;
    ans[1][1] = g[1].c;
    dfs2(1, 0);
    FOR(i, 1, n) cout << ans[i][0] << endl;
}