前言
本文思路与此博客一样,部分引用已经过本人授权,不过他不想写了,所以我帮他完善/修正一下。
思路
首先观察一下合法的情况 S 满足什么条件,显然在这种情况下 S 能且仅能转到一种状态 T,同时 T 也只能转到 S,我们将用过的边染为红边,没用过的边染为黑边,选的点为黑点,没选的为白点,如下图所示
对于一条链来讲,肯定是一个白点加一堆黑点且链长度必须大于一,我们记这条链的白色点为头,最后面的黑色点为尾,中间的为身子,注意身子可能并不存在。
容易发现一些性质,头和尾显然只会连接一条红边,若只保留黑边的话,头只能连尾,尾只能连头,身子只能连身子,也就是说不然 S 到 T 后将面临无法划分的情况,如果不理解建议读者自行画图理解。
考虑树形 DP
,有以下 7 个状态:
转移也就是:
-
f_{x,0}= \prod_{y \in son_{x}} f_{y,3}
-
f_{x,1}= \sum_{y \in son_{x}}(f_{y,4}+f_{y,2}) \prod_{z \ne y}f_{z,3}
-
f_{x,2}=\prod _{y \in son_{x}} f_{y,1}
-
f_{x,3}= \sum_{y \in son_{x}}(f_{y,5}+f_{y,0})\prod_{z \ne y} f_{z,1}
-
f_{x,4}= \sum_{y \in son_{x}} (f_{y,2}+f_{y,4}) \prod_{z \ne y} f_{z,6}
-
f_{x,5}= \sum_{y \in son_{x}} (f_{y,0}+f_{y,5}) \prod_{z \ne y} f_{z,6}
-
f_{x,6} = \sum_{y \in son_{x},z \in son_{x},z\ne y} (f_{y,0}+f_{y,5})(f_{z,2}+f_{z,4}) \prod_{k \ne y,k \ne z} f_{k,6}
这个显然是会超时的,考虑优化。
$f_{x,1},f_{x,2}$ 直接预处理 $f_{z_3},f_{z_1}$ 的乘积,枚举 $y$ 的时候额外乘上 $f_{y,3},f_{z,1}$ 的逆元即可。
$f_{x,4},f_{x,5}$ 同理,参考上面的做法,不多叙述。
然后着重看 $f_{x,6}$ 的处理,枚举一个 $y$,预处理每个 $f_{z,2}+f_{z,4}$ 乘上所有 $f_{k,6}$ 的乘积在乘上 $f_{z,6}$ 的逆元,设为 $g_z$,然后记 $sum$ 为 $\sum_{o \in son_{x}} g_o$,则 $f_{x,6}$ 就等于 $(sum-g_y)\times(f_{y,0}+f_{y,5})/f_{y,6}$,容易发现除了 $y$ 本身其他的每个都多成了个 $f_{y,6}$,除去就好了。
这样本来就结束了,但细心的小朋友会发现,**0不能作为除数**。
不过没事,我们在开一些数组存为 $0$ 的个数有几个,这样就可以了,具体的可以看代码实现。
不过 $f_{x,6}$ 有些特殊,我们需要分讨解决,这是因为它需要枚举两个数,如果根其它的一样判 $0$,复杂度是不可接受的。
具体的,当 $sum_0 > 2$ 时,显然答案为 $0$。
当 $sum_0 = 2$ 时,显然选的数固定了,其他算出来都为 $0$。
当 $sum_0 = 1$ 时,我们新记一个 $sum2$ 表示不乘 $0$ 的贡献是多少,若 $f_{y,6} = 0$ 算贡献时不乘逆元即可,否则正常转。
当 $sum_0 = 0$ 时,正常转即可。
具体的细节和易错点可以看代码,加了点注释。
**code**
```cpp
#include<bits/stdc++.h>
#define int long long
using namespace std;
#define getchar() (p1 == p2 && (p2 = (p1 = buf1) + fread(buf1, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf1[1 << 23], *p1 = buf1, *p2 = buf1, ubuf[1 << 23], *u = ubuf;
namespace IO
{
template<typename T>
void read(T &_x){_x=0;int _f=1;char ch=getchar();while(!isdigit(ch)) _f=(ch=='-'?-1:_f),ch=getchar();while(isdigit(ch)) _x=_x*10+(ch^48),ch=getchar();_x*=_f;}
template<typename T,typename... Args>
void read(T &_x,Args&...others){Read(_x);Read(others...);}
const int BUF=20000000;char buf[BUF],to,stk[32];int plen;
#define pc(x) buf[plen++]=x
#define flush(); fwrite(buf,1,plen,stdout),plen=0;
template<typename T>inline void print(T x){if(!x){pc(48);return;}if(x<0) x=-x,pc('-');for(;x;x/=10) stk[++to]=48+x%10;while(to) pc(stk[to--]);}
}
using namespace IO;
const int N = 2e5+10,mod = 998244353;
int n,head[N],f[N][20],cnt,x,y;
int sum[N][20],sum1[N][20],op,op1,X,Y,l;
int o,p;
/*
f_i_0 i为头,连接父亲的为红边
f_i_1 i为头,连接父亲的为黑边
f_i_2 i为尾,连接父亲的为红边
f_i_3 i为尾,连接父亲的为黑边
f_i_4 i为身子,开头是头
f_i_5 i为身子,开头是尾
f_i_6 i为身子,连接父亲的为黑边,也就是从子树中选出两条不同的链,一个是头,一个是尾,也就是f_u_5,f_u_0和f_v_4,f_v_2,剩下的都断开,取f_x_6
容易观察到,只保留黑边的话,头只能连尾,尾只能连头,身子只能连身子
*/
struct w
{
int to,nxt;
}b[N<<1];
inline int ksm(int x,int p)
{
x %= mod;//避免爆long long
int ans = 1;
while(p)
{
if((p&1)) ans = ans*x%mod;
x = x*x%mod;
p >>= 1;
}
return ans;
}
inline void add(int x,int y)
{
b[++cnt].nxt = head[x];
b[cnt].to = y;
head[x] = cnt;
}
void dfs(int x,int y)
{
f[x][0] = f[x][2] = 1; //初值
int fu = 0;
for(int i = 0;i <= 10;i++) sum[x][i] = 1,sum1[x][i] = 0;
sum[x][5] = sum[x][3] = 0;
for(int i = head[x];i;i = b[i].nxt)
if(b[i].to != y)
{
fu++;
dfs(b[i].to,x);
if(f[b[i].to][3] == 0) sum1[x][0]++;
else sum[x][0] = sum[x][0]*f[b[i].to][3]%mod;
if(f[b[i].to][1] == 0) sum1[x][1]++;
else sum[x][1] = sum[x][1]*f[b[i].to][1]%mod;
if(f[b[i].to][6] == 0) sum1[x][4]++;
else sum[x][4] = sum[x][4]*f[b[i].to][6]%mod;
}
for(int i = head[x];i;i = b[i].nxt)
if(b[i].to != y)
{
if(f[b[i].to][6] == 0) sum1[x][4]--;
else sum[x][4] = sum[x][4]*ksm(f[b[i].to][6],mod-2)%mod;
op = (sum1[x][4]==0)*sum[x][4],sum[x][5] = (sum[x][5]+op*(f[b[i].to][2]+f[b[i].to][4])%mod)%mod;//算上0时的贡献
op = sum[x][4],sum[x][3] = (sum[x][3]+op*(f[b[i].to][2]+f[b[i].to][4])%mod)%mod;//不算上0时的贡献
if(f[b[i].to][6] == 0) sum1[x][4]++;
else sum[x][4] = sum[x][4]*f[b[i].to][6]%mod;
}
op = (sum1[x][0]==0)*sum[x][0],f[x][0] = f[x][0]*op%mod;
op = (sum1[x][1]==0)*sum[x][1],f[x][2] = f[x][2]*op%mod;
l = 0;
if(sum1[x][4] == 2)//有两个0
{
l = 1;
X = Y = 0;
for(int i = head[x];i;i = b[i].nxt)
if(b[i].to != y && f[b[i].to][6] == 0)
{
if(X == 0) X = b[i].to;
else Y = b[i].to;
}
f[x][6] = ((f[X][0]+f[X][5])%mod*((f[Y][2]+f[Y][4])%mod)%mod+(f[X][2]+f[X][4])%mod*((f[Y][0]+f[Y][5])%mod)%mod)*sum[x][4]%mod;
}
else if(fu == 2)//这是个很特殊的,只有两个点因为根本不需要枚举k
{
l = 1;
X = Y = 0;
for(int i = head[x];i;i = b[i].nxt)
if(b[i].to != y)
{
if(X == 0) X = b[i].to;
else Y = b[i].to;
}
f[x][6] = ((f[X][0]+f[X][5])%mod*((f[Y][2]+f[Y][4])%mod)%mod+(f[X][2]+f[X][4])%mod*((f[Y][0]+f[Y][5])%mod)%mod)%mod;
}
for(int i = head[x];i;i = b[i].nxt)
if(b[i].to != y)
{
if(f[b[i].to][3] == 0) sum1[x][0]--;
else sum[x][0] = sum[x][0]*ksm(f[b[i].to][3],mod-2)%mod;
if(f[b[i].to][1] == 0) sum1[x][1]--;
else sum[x][1] = sum[x][1]*ksm(f[b[i].to][1],mod-2)%mod;
if(f[b[i].to][6] == 0) sum1[x][4]--;
else sum[x][4] = sum[x][4]*ksm(f[b[i].to][6],mod-2)%mod;
if(fu > 1)
{
op = (sum1[x][0]==0)*sum[x][0],f[x][1] = (f[x][1]+(f[b[i].to][2]+f[b[i].to][4])*op%mod)%mod;
op = (sum1[x][1]==0)*sum[x][1],f[x][3] = (f[x][3]+(f[b[i].to][0]+f[b[i].to][5])*op%mod)%mod;
op = (sum1[x][4]==0)*sum[x][4],f[x][4] = (f[x][4]+(f[b[i].to][2]+f[b[i].to][4])*op%mod)%mod;
op = (sum1[x][4]==0)*sum[x][4],f[x][5] = (f[x][5]+(f[b[i].to][0]+f[b[i].to][5])*op%mod)%mod;
//正常转移
}
if(fu == 1)//为1时不需要枚举
{
f[x][1] = (f[x][1]+(f[b[i].to][2]+f[b[i].to][4])%mod)%mod;
f[x][3] = (f[x][3]+(f[b[i].to][0]+f[b[i].to][5])%mod)%mod;
f[x][4] = (f[x][4]+(f[b[i].to][2]+f[b[i].to][4])%mod)%mod;
f[x][5] = (f[x][5]+(f[b[i].to][0]+f[b[i].to][5])%mod)%mod;
}
if(fu > 2 && !l && sum1[x][4] <= 1)//显然必须要有大于两个点,两个点外面已经算了,若l为1就是特殊情况之前处理过了
{
if(f[b[i].to][6] == 0)//为0的话,就用sum_x_3的转移
{
op = sum[x][4],op1 = (sum[x][3]-op*(f[b[i].to][2]+f[b[i].to][4])%mod+mod)%mod;//减去他自己的
f[x][6] = (f[x][6]+op1*(f[b[i].to][0]+f[b[i].to][5])%mod)%mod;
}
else
{
op = (sum1[x][4]==0)*sum[x][4],op1 = (sum[x][5]-op*(f[b[i].to][2]+f[b[i].to][4])%mod+mod)%mod;//减去他自己的
f[x][6] = (f[x][6]+op1*(f[b[i].to][0]+f[b[i].to][5])%mod*ksm(f[b[i].to][6],mod-2)%mod)%mod;
}
}
if(f[b[i].to][3] == 0) sum1[x][0]++;
else sum[x][0] = sum[x][0]*f[b[i].to][3]%mod;
if(f[b[i].to][1] == 0) sum1[x][1]++;
else sum[x][1] = sum[x][1]*f[b[i].to][1]%mod;
if(f[b[i].to][6] == 0) sum1[x][4]++;
else sum[x][4] = sum[x][4]*f[b[i].to][6]%mod;//加回去
}
}
signed main()
{
// freopen("chess7.in","r",stdin);
// freopen("chess.out","w",stdout);
read(n);
for(int i = 1;i < n;i++) read(x),read(y),add(x,y),add(y,x);
dfs(1,0);
print((f[1][1]+f[1][3]+f[1][6])%mod); flush();
return 0;
}
```