MX-S6 T4 彩灯晚会

Petit_Souris

2024-11-17 13:22:23

Solution

验题人题解,大概做了 1.5h 不到过了。

首先这个 \sum cnt_i^2 典的不能再典了,用组合意义转化为「选出两条长度为 l 的同色链的方案数」。

把贡献拆开,枚举两条链,计算这两条链同色的方案数。假设他们重合了 c 个点,那么方案数为 k^{n-2l+c+1}。现在问题转化为,给定 c,计算选出两条长度为 l 的链,恰好重合 c 个点的方案数。

考虑容斥(二项式反演)。计算 g_c 表示钦定了 c 个点重合的方案数,二项式反演得到 f_c 表示恰好 c 个点重合的方案数。

如何计算 g?由于图是 DAG,我们可以在拓扑序上做 DP。具体来说,设 dp_{u,i,l_1,l_2} 表示目前钦定两条链重合于 u,已经钦定了 i 个重合点,两条链的目前长度分别为 l_1,l_2。转移枚举下一个重合点 v 和两条链在 u\to v 这段的长度 t_1,t_2,转移到 dp_{v,i+1,l_1+t_1,l_2+t_2}。这个转移形式意味着我们得预处理出 h_{u,v,i} 表示 u\to v 的长度为 i 的路径数量,还需要预处理出 s_{u,i}t_{u,i} 表示从 u 开始 / 结束于 u 的,长度为 i 的路径数量。

这个做法的时间复杂度为 \mathcal O(n^2l^5+n^3l),得到 76~84 分。

考虑优化。首先这个转移同时枚举 t_1,t_2 是不优的,因为转移系数没有同时和 t_1,t_2 有关的项,所以可以做分步转移,先转移 t_1 再转移 t_2。时间复杂度 \mathcal O(n^2l^4+n^3l),得到 84~92 分。

还能再给力点吗?发现我们最后其实并不关心每个 c 对应的方案数,我们只要算出对应的 k^c 就行了。这启发我们把所有的系数直接在转移过程中均摊掉。然而由于有二项式反演的存在,每转移一步就乘上 k 是错的。

正确的打开方式是考虑二项式反演中 i\to c 的贡献,把 k^c 乘进去之后是个二项式定理的形式,可以归纳得到正确的系数为 (k-1)^c。于是我们 DP 时不需要记录 i 这一位,直接每走一步乘上一个 (k-1) 就行了。时间复杂度 \mathcal O(n^2l^3+n^3l),可以获得 100 分。

稍微卡卡常数就不大了。

#include<bits/stdc++.h>
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
#define pii pair<ll,ll>
#define rep(i,a,b) for(ll i=(a);i<=(b);++i)
#define per(i,a,b) for(ll i=(a);i>=(b);--i)
using namespace std;
bool Mbe;
ll read(){
    ll x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
void write(ll x){
    if(x<0)putchar('-'),x=-x;
    if(x>9)write(x/10);
    putchar(x%10+'0');
}
const ll N=309,Mod=998244353;
ll typ,n,k,L,M;
ll f[N][N][22],ed[N][22],st[N][22],g[N][22][22],h[N][22][22],C[22][22];
ll ord[N],deg[N];
vector<pii>to[N];
ll pw(ll x,ll p){
    ll res=1;
    while(p){
        if(p&1)res=res*x%Mod;
        x=x*x%Mod,p>>=1;
    }
    return res;
}
bool Med;
int main(){
    freopen("party.in","r",stdin);
    freopen("party.out","w",stdout);
    cerr<<fabs(&Med-&Mbe)/1048576.0<<"MB\n";
    typ=read(),n=read(),k=read(),L=read(),M=read();
    rep(i,1,M){
        ll x=read(),y=read(),z=read();
        to[x].push_back({y,z}),deg[y]++;
    }
    rep(i,0,L){
        C[i][0]=1;
        rep(j,1,i)C[i][j]=(C[i-1][j]+C[i-1][j-1])%Mod;
    }
    queue<ll>q;
    rep(i,1,n){
        if(!deg[i])q.push(i);
    }
    while(!q.empty()){
        ll u=q.front();q.pop();
        ord[++ord[0]]=u;
        for(pii e:to[u]){
            ll v=e.first;
            deg[v]--;
            if(!deg[v])q.push(v);
        }
    }
    ord[0]=0;
    rep(i,1,n){
        ll u=ord[i];
        f[u][u][1]=1;
        rep(j,i,n){
            ll v=ord[j];
            rep(k,1,L-1){
                if(!f[u][v][k])continue;
                for(pii e:to[v]){
                    ll w=e.first,ww=e.second;
                    f[u][w][k+1]=(f[u][w][k+1]+f[u][v][k]*ww)%Mod;
                }
            }
        }
    }
    rep(i,1,n){
        rep(j,1,n){
            rep(k,1,L){
                st[i][k]=(st[i][k]+f[i][j][k])%Mod;
                ed[j][k]=(ed[j][k]+f[i][j][k])%Mod;
            }
        }
    }
    cerr<<"\n"<<clock()*1.0/CLOCKS_PER_SEC*1000<<"ms\n";
    ll ans=0;
    rep(i,1,n){
        rep(j,1,n)ans=(ans+f[i][j][L])%Mod;
    }
    ans=ans*ans%Mod;
    rep(i,1,n){
        ll u=ord[i];
        rep(l1,1,L){
            rep(l2,1,L)g[u][l1][l2]=(g[u][l1][l2]+ed[u][l1]*ed[u][l2]%Mod*(k-1))%Mod;
        }
        memset(h,0,sizeof(h));
        rep(j,i+1,n){
            ll v=ord[j];
            rep(l1,1,L){
                rep(l2,1,L){
                    if(!g[u][l1][l2])continue;
                    rep(t1,1,L-l1){
                        if(!f[u][v][t1+1])continue;
                        h[v][l1+t1][l2]=(h[v][l1+t1][l2]+g[u][l1][l2]*f[u][v][t1+1])%Mod;
                    }
                }
            }
        }
        rep(j,i+1,n){
            ll v=ord[j];
            rep(l1,1,L){
                rep(l2,1,L){
                    if(!h[v][l1][l2])continue;
                    rep(t2,1,L-l2){
                        if(!f[u][v][t2+1])continue;
                        g[v][l1][l2+t2]=(g[v][l1][l2+t2]+h[v][l1][l2]*f[u][v][t2+1]%Mod*(k-1))%Mod;
                    }
                }
            }
        }
    }
    rep(i,1,n){
        ll u=ord[i];
        rep(l1,1,L){
            rep(l2,1,L){
                if(!g[u][l1][l2])continue;
                ans=(ans+g[u][l1][l2]*st[u][L-l1+1]%Mod*st[u][L-l2+1])%Mod;
            }
        }
    }
    if(n-2*L+1<0)ans=ans*pw(pw(k,Mod-2),-(n-2*L+1))%Mod;
    else ans=ans*pw(k,n-2*L+1)%Mod;
    write(ans),putchar('\n');
    cerr<<"\n"<<clock()*1.0/CLOCKS_PER_SEC*1000<<"ms\n";
    return 0;
}