小超手123
2025-01-07 18:21:27
给定一棵 $n$ 个节点的树(原题面此处被截断,部分公式缺失)。
定义一次操作为:(原题面此处内容缺失)。
在一次操作内,所有棋子的移动是同时进行的,并且需要遵循以下规则:
每条树边最多被一颗棋子经过。
移动后每个节点上至多有一颗棋子。
现在你需要统计(原题面此处内容缺失)。
初始化(原题面此处内容缺失)
#include<bits/stdc++.h>
#define int long long
using namespace std;
// Buffered input: getchar() is redefined to serve bytes out of buf1, refilling
// it from stdin in chunks of (1 << 21) bytes; it yields EOF once fread returns 0.
#define getchar() \
    (p1 == p2 && (p2 = (p1 = buf1) + fread(buf1, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf1[1 << 23];
char *p1 = buf1;   // next unread byte
char *p2 = buf1;   // one past the last valid byte
// NOTE(review): ubuf / u are not referenced anywhere else in this file.
char ubuf[1 << 23];
char *u = ubuf;
namespace IO
{
// Reads one decimal integer into _x using getchar().
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. Skipping also stops at EOF, so reading past the end of
// the input leaves _x == 0 instead of spinning forever on EOF.
// Digits are tested with a range compare rather than isdigit(), which is
// undefined for negative char values other than EOF.
template<typename T>
void read(T &_x)
{
    _x = 0;
    int _f = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') _f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        _x = _x * 10 + (ch ^ 48);
        ch = getchar();
    }
    _x *= _f;
}
// Reads several integers in order: read(a, b, c).
// (Previously forwarded to an undeclared `Read`, so any multi-argument call
// failed to compile.)
template<typename T, typename... Args>
void read(T &_x, Args&... others)
{
    read(_x);
    read(others...);
}
// Output buffer: print() appends to buf, flush() writes it to stdout.
const int BUF = 20000000;
char buf[BUF], to, stk[32];
int plen;
#define pc(x) (buf[plen++] = (x))
// Expression-form macro so `flush();` is a single statement (safe in if/else).
#define flush() (fwrite(buf, 1, plen, stdout), plen = 0)
// Appends the decimal representation of x to buf (no separator, no flush).
template<typename T>
inline void print(T x)
{
    if (!x) { pc(48); return; }
    if (x < 0) x = -x, pc('-');
    for (; x; x /= 10) stk[++to] = 48 + x % 10;   // digits, least significant first
    while (to) pc(stk[to--]);                    // emit them in reverse
}
}
using namespace IO;
const int N = 2e5+10;        // vertex capacity; the edge pool uses N << 1
const int mod = 998244353;   // modulus for every DP value and the answer
// Input / tree storage.
int n, head[N], cnt, x, y;
// DP table. The second dimension is 20, but dfs only touches f[.][0..6]
// and sum/sum1[.][0..10].
int f[N][20];
// Per-vertex scratch for dfs: sum holds products/sums mod `mod`,
// sum1 counts factors that were zero.
int sum[N][20], sum1[N][20];
// Shared temporaries used by dfs outside its recursive calls.
int op, op1, X, Y, l;
int o, p;   // NOTE(review): o is unused; p is only shadowed by ksm's parameter
/*
DP states for vertex i (translated from the author's notes; "red" / "black"
are edge colours from the original problem, whose definition is not in this file):
  f[i][0]  i is a head, the edge to its parent is red
  f[i][1]  i is a head, the edge to its parent is black
  f[i][2]  i is a tail, the edge to its parent is red
  f[i][3]  i is a tail, the edge to its parent is black
  f[i][4]  i is a body node, and its chain starts with a head
  f[i][5]  i is a body node, and its chain starts with a tail
  f[i][6]  i is a body node, the edge to its parent is black: choose two different
           child chains, one a head and one a tail (f[u][5] / f[u][0] for
           one child, f[v][4] / f[v][2] for the other); every remaining child
           is disconnected and contributes its own f[.][6]
Observation: keeping only black edges, a head can only connect to a tail,
a tail only to a head, and a body only to a body.
*/
// Edge pool for the adjacency lists built by add(). It is sized N << 1 because
// main stores each undirected tree edge once per direction.
struct w
{
    int to;    // vertex this edge points to
    int nxt;   // index of the next edge leaving the same vertex (0 = end)
} b[N << 1];
// Modular exponentiation by repeated squaring: returns x^p mod `mod` for p >= 0.
// x is reduced first so every product of two residues fits in long long
// (int is #defined to long long in this file).
inline int ksm(int x,int p)
{
    int base = x % mod;
    int result = 1;
    for (; p != 0; p >>= 1)
    {
        if (p & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}
// Pushes the directed edge x -> y onto the front of x's adjacency list.
// Edge indices start at 1, so 0 terminates a list (head[] / b[].nxt).
inline void add(int x,int y)
{
    const int id = ++cnt;
    b[id].to = y;
    b[id].nxt = head[x];
    head[x] = id;
}
// Post-order tree DP: fills f[x][0..6] for vertex x (parent y) from its
// children's f values.
//
// Per-vertex scratch (all values mod `mod`):
//   sum[x][0], sum1[x][0]  product of the nonzero f[c][3] over children c, and
//                          how many children have f[c][3] == 0
//   sum[x][1], sum1[x][1]  the same for f[c][1]
//   sum[x][4], sum1[x][4]  the same for f[c][6]
//   sum[x][5]  sum over c of (f[c][2]+f[c][4]) * prod_{d != c} f[d][6]
//              (a term is dropped when some other f[d][6] is zero)
//   sum[x][3]  like sum[x][5], but the product skips zero f[d][6] factors
//              instead of dropping the term
// A product over "all children except c" is obtained by dividing c's factor
// out with a Fermat inverse ksm(v, mod-2). Zero factors have no inverse, so
// they are counted in sum1 instead of being multiplied in.
// The globals op/op1/X/Y/l are only used after this frame's recursive calls
// have returned.
// NOTE(review): the recursion depth equals the tree height; a path-shaped
// tree with ~2e5 vertices may need more than the default stack size.
void dfs(int x,int y)
{
    // Leaf values: f[x][0] = f[x][2] = 1 (empty products); the rest stay 0.
    f[x][0] = f[x][2] = 1; int fu = 0;   // fu = number of children
    for(int i = 0;i <= 10;i++) sum[x][i] = 1,sum1[x][i] = 0;
    sum[x][5] = sum[x][3] = 0;           // these two are sums, not products

    // Pass 1: recurse into the children and build the zero-aware products.
    for(int i = head[x];i;i = b[i].nxt)
        if(b[i].to != y)
        {
            fu++;
            dfs(b[i].to,x);
            if(f[b[i].to][3] == 0) sum1[x][0]++;
            else sum[x][0] = sum[x][0]*f[b[i].to][3]%mod;
            if(f[b[i].to][1] == 0) sum1[x][1]++;
            else sum[x][1] = sum[x][1]*f[b[i].to][1]%mod;
            if(f[b[i].to][6] == 0) sum1[x][4]++;
            else sum[x][4] = sum[x][4]*f[b[i].to][6]%mod;
        }

    // Pass 2: for each child c, temporarily remove c's f6 factor and
    // accumulate sum[x][5] and sum[x][3] (see the header), then restore it.
    for(int i = head[x];i;i = b[i].nxt)
        if(b[i].to != y)
        {
            if(f[b[i].to][6] == 0) sum1[x][4]--;
            else sum[x][4] = sum[x][4]*ksm(f[b[i].to][6],mod-2)%mod;
            op = (sum1[x][4]==0)*sum[x][4],sum[x][5] = (sum[x][5]+op*(f[b[i].to][2]+f[b[i].to][4])%mod)%mod;
            op = sum[x][4],sum[x][3] = (sum[x][3]+op*(f[b[i].to][2]+f[b[i].to][4])%mod)%mod;
            if(f[b[i].to][6] == 0) sum1[x][4]++;
            else sum[x][4] = sum[x][4]*f[b[i].to][6]%mod;
        }

    // f[x][0] = prod f[c][3] and f[x][2] = prod f[c][1]; each is 0 if any factor is 0.
    op = (sum1[x][0]==0)*sum[x][0],f[x][0] = f[x][0]*op%mod;
    op = (sum1[x][1]==0)*sum[x][1],f[x][2] = f[x][2]*op%mod;

    // f[x][6] for the cases the per-child formula in pass 3 cannot handle.
    // l = 1 marks f[x][6] as final.
    l = 0;
    if(sum1[x][4] == 2)
    {
        // Exactly two children X, Y have f6 == 0. Any other head/tail pair
        // would leave a zero factor, so only {X, Y} contributes, multiplied by
        // the product of the remaining (nonzero) f6 values.
        // This must test the zero COUNT sum1[x][4], not the product sum[x][4].
        l = 1;
        X = Y = 0;
        for(int i = head[x];i;i = b[i].nxt)
            if(b[i].to != y && f[b[i].to][6] == 0)
            {
                if(X == 0) X = b[i].to;
                else Y = b[i].to;
            }
        f[x][6] = ((f[X][0]+f[X][5])%mod*((f[Y][2]+f[Y][4])%mod)%mod+(f[X][2]+f[X][4])%mod*((f[Y][0]+f[Y][5])%mod)%mod)*sum[x][4]%mod;
    }
    else if(fu == 2)
    {
        // Exactly two children: the pair is forced and no other factor applies.
        l = 1;
        X = Y = 0;
        for(int i = head[x];i;i = b[i].nxt)
            if(b[i].to != y)
            {
                if(X == 0) X = b[i].to;
                else Y = b[i].to;
            }
        f[x][6] = ((f[X][0]+f[X][5])%mod*((f[Y][2]+f[Y][4])%mod)%mod+(f[X][2]+f[X][4])%mod*((f[Y][0]+f[Y][5])%mod)%mod)%mod;
    }

    // Pass 3: for each child c, remove c from all three products and add the
    // "c is the special child" terms to f[x][1], f[x][3], f[x][4], f[x][5] and,
    // when still needed, f[x][6]; then put c back.
    for(int i = head[x];i;i = b[i].nxt)
        if(b[i].to != y)
        {
            if(f[b[i].to][3] == 0) sum1[x][0]--;
            else sum[x][0] = sum[x][0]*ksm(f[b[i].to][3],mod-2)%mod;
            if(f[b[i].to][1] == 0) sum1[x][1]--;
            else sum[x][1] = sum[x][1]*ksm(f[b[i].to][1],mod-2)%mod;
            if(f[b[i].to][6] == 0) sum1[x][4]--;
            else sum[x][4] = sum[x][4]*ksm(f[b[i].to][6],mod-2)%mod;
            if(fu > 1)
            {
                op = (sum1[x][0]==0)*sum[x][0],f[x][1] = (f[x][1]+(f[b[i].to][2]+f[b[i].to][4])*op%mod)%mod;
                op = (sum1[x][1]==0)*sum[x][1],f[x][3] = (f[x][3]+(f[b[i].to][0]+f[b[i].to][5])*op%mod)%mod;
                op = (sum1[x][4]==0)*sum[x][4],f[x][4] = (f[x][4]+(f[b[i].to][2]+f[b[i].to][4])*op%mod)%mod;
                op = (sum1[x][4]==0)*sum[x][4],f[x][5] = (f[x][5]+(f[b[i].to][0]+f[b[i].to][5])*op%mod)%mod;
            }
            if(fu == 1)
            {
                // Single child: the product over the other children is empty (= 1).
                f[x][1] = (f[x][1]+(f[b[i].to][2]+f[b[i].to][4])%mod)%mod;
                f[x][3] = (f[x][3]+(f[b[i].to][0]+f[b[i].to][5])%mod)%mod;
                f[x][4] = (f[x][4]+(f[b[i].to][2]+f[b[i].to][4])%mod)%mod;
                f[x][5] = (f[x][5]+(f[b[i].to][0]+f[b[i].to][5])%mod)%mod;
            }
            // General f[x][6] with c on the (f0 + f5) side. Only reached when at
            // most one child in total has f6 == 0 (two zeros set l above;
            // three or more leave sum1 >= 2 here, and f[x][6] correctly stays 0).
            if(fu > 2 && !l && sum1[x][4] <= 1)
            {
                if(f[b[i].to][6] == 0)
                {
                    // c is the only zero: sum[x][3] minus c's own term is the
                    // sum over d != c of (f[d][2]+f[d][4]) * prod_{e != c,d} f[e][6].
                    op = sum[x][4],op1 = (sum[x][3]-op*(f[b[i].to][2]+f[b[i].to][4])%mod+mod)%mod;
                    f[x][6] = (f[x][6]+op1*(f[b[i].to][0]+f[b[i].to][5])%mod)%mod;
                }
                else
                {
                    // c's f6 is nonzero: take sum[x][5] minus c's own term,
                    // then divide c's f6 back out.
                    op = (sum1[x][4]==0)*sum[x][4],op1 = (sum[x][5]-op*(f[b[i].to][2]+f[b[i].to][4])%mod+mod)%mod;
                    f[x][6] = (f[x][6]+op1*(f[b[i].to][0]+f[b[i].to][5])%mod*ksm(f[b[i].to][6],mod-2)%mod)%mod;
                }
            }
            if(f[b[i].to][3] == 0) sum1[x][0]++;
            else sum[x][0] = sum[x][0]*f[b[i].to][3]%mod;
            if(f[b[i].to][1] == 0) sum1[x][1]++;
            else sum[x][1] = sum[x][1]*f[b[i].to][1]%mod;
            if(f[b[i].to][6] == 0) sum1[x][4]++;
            else sum[x][4] = sum[x][4]*f[b[i].to][6]%mod;
        }
}
// Entry point: reads n and the n-1 tree edges, runs the DP rooted at vertex 1,
// and prints (f[1][1] + f[1][3] + f[1][6]) mod `mod`.
signed main()
{
    // Local testing: freopen("chess7.in","r",stdin); freopen("chess.out","w",stdout);
    read(n);
    for (int edge = 1; edge < n; ++edge)
    {
        read(x);
        read(y);
        add(x, y);   // store the undirected edge in both directions
        add(y, x);
    }
    dfs(1, 0);       // parent 0 = none (assumes 1-based vertex ids)
    const int answer = (f[1][1] + f[1][3] + f[1][6]) % mod;
    print(answer);
    flush();
    return 0;
}