题解:AT_abc242_h [ABC242Ex] Random Painting

_maojun_

2025-01-07 12:00:52

Solution

实现了一下学长讲的 O(m^2\log n) 做法。

设状态为 S=\{(l_i,r_i)\},合法为这些区间的并达到 [1,n]

停时的期望等于所有不合法状态的“出现概率”与“离开期望步数”乘积的总和。

大小为 k 的集合的“离开期望步数”为 \dfrac m{m-k},只需求大小为 k 的不合法集合出现总概率。

区间双关键字排序后,设 f_{i,j,k} 表示考虑了前 i 个区间,覆盖了 [1,j],用了 k 的区间的方案数,则最后不合法的概率:

p_k=\dfrac{{m\choose k}-f_{m,n,k}}{m\choose k} $$f_{i,j,k}\gets f_{i-1,j,k}\\f_{i,\max\{j,r_i\},k+1}\gets f_{i-1,j,k},j\ge l_i-1$$ 可以做到 $O(nm^2)$。 ```cpp typedef pair<int,int> pi; #define fi first #define se second typedef long long ll; const int N=405,MOD=998244353; int n,m;pi a[N]; inline ll ksm(ll x,int y=MOD-2){ll r=1;for(;y;y>>=1,x=x*x%MOD)if(y&1)r=r*x%MOD;return r;} ll I[N],C[N]; inline void ad(int&x,int y){x+=y;x>=MOD&&(x-=MOD);} int f[N][N][N]; inline void main(){ scanf("%d%d",&n,&m); for(int i=1;i<=m;i++)scanf("%d%d",&a[i].fi,&a[i].se); sort(a+1,a+m+1); I[1]=1;for(int i=2;i<=m;i++)I[i]=I[MOD%i]*(MOD-MOD/i)%MOD; C[0]=1;for(int i=1;i<=m;i++)C[i]=C[i-1]*i%MOD*I[m-i+1]%MOD; f[0][0][0]=1; for(int i=1;i<=m;i++)for(int j=0;j<=n;j++)for(int k=0;k<i;k++)if(f[i-1][j][k]){ ad(f[i][j][k],f[i-1][j][k]); if(a[i].fi<=j+1)ad(f[i][max(j,a[i].se)][k+1],f[i-1][j][k]); } ll rs=0; for(int k=0;k<m;k++)rs=(rs+(1-f[m][n][k]*C[k])%MOD*I[m-k])%MOD; printf("%lld\n",(rs+MOD)*m%MOD); } ``` --- 三维状态,比较难以优化,但是可以把第三维写成生成函数形式:$F_{i,j}(x)=\sum f_{i,j,k}x^k$,则转移变为: $$F_{i,j}(x)\gets F_{i-1,j}(x)\\F_{i,\max\{j,r_i\}}(x)\gets xF_{i-1,j}(x),j\ge l_i-1$$ 最后代入 $m+O(1)$ 个 $x$ 把 $F_{m,n}$ 的系数插出来。 这个转移的过程形如单点修改区间乘法和查询区间和,线段树维护即可,复杂度 $O(m^2\log n)$。 --- 代码可能写得比较丑。 ```cpp typedef pair<int,int> pi; #define fi first #define se second typedef long long ll; const int N=405,MOD=998244353; int n,m;pi a[N]; inline ll ksm(ll x,int y=MOD-2){ll r=1;for(;y;y>>=1,x=x*x%MOD)if(y&1)r=r*x%MOD;return r;} ll I[N],C[N]; const int S=N<<2; ll vl[S],tg[S]; #define ls p<<1 #define rs p<<1|1 #define md (l+r>>1) #define Ls ls,l,md #define Rs rs,md+1,r #define al 1,0,n inline void pu(int p){vl[p]=(vl[ls]+vl[rs])%MOD;} inline void ch(int p,int k){vl[p]=vl[p]*k%MOD;tg[p]=tg[p]*k%MOD;} inline void pd(int p){if(tg[p]^1){ch(ls,tg[p]);ch(rs,tg[p]);tg[p]=1;}} void B(int p,int l,int r){vl[p]=!l;tg[p]=1;if(l==r)return;B(Ls);B(Rs);pu(p);} void U1(int p,int l,int r,int x,int k){ vl[p]=(vl[p]+k)%MOD;if(l==r)return;pd(p);x<=md?U1(Ls,x,k):U1(Rs,x,k); } void U2(int p,int l,int r,int L,int k){ if(L<=l)return ch(p,k);pd(p);if(L<=md)U2(Ls,L,k);U2(Rs,L,k);pu(p); } ll Q(int p,int l,int r,int L,int R){ if(L<=l&&r<=R)return vl[p];pd(p);return L>md?Q(Rs,L,R):R<=md?Q(Ls,L,R):(Q(Ls,L,R)+Q(Rs,L,R))%MOD; } const int G=3,iG=ksm(G); int Ln,Rv[512],*mG[10],wG; inline void INIT(){ wG=ksm(G,MOD-1>>9); for(int i=1;i<512;i++)Rv[i]=Rv[i>>1]>>1|(i&1)<<8; for(int k=0;k<9;k++){ const ll w=ksm(iG,MOD-1>>k+1); mG[k]=new int[1<<k];mG[k][0]=1; for(int i=1;i<1<<k;i++)mG[k][i]=mG[k][i-1]*w%MOD; } } inline void NTT(int*A){ for(int i=0;i<512;i++)if(i<Rv[i])swap(A[i],A[Rv[i]]); for(int k=0;k<9;k++)for(int i=0;i<512;i+=1<<k+1)for(int j=0;j<1<<k;j++){ int x=A[i|j],y=(ll)A[i|1<<k|j]*mG[k][j]%MOD; A[i|j]=(x+y)%MOD;A[i|1<<k|j]=(x-y+MOD)%MOD; } } int A[512]; inline void main(){ scanf("%d%d",&n,&m); for(int i=1;i<=m;i++)scanf("%d%d",&a[i].fi,&a[i].se); sort(a+1,a+m+1); I[1]=1;for(int i=2;i<=m;i++)I[i]=I[MOD%i]*(MOD-MOD/i)%MOD; C[0]=1;for(int i=1;i<=m;i++)C[i]=C[i-1]*i%MOD*I[m-i+1]%MOD; INIT(); for(int t=0,x=1;t<512;t++,x=(ll)x*wG%MOD){ B(al); for(int i=1,l=0;i<=m;i++){ l=max(l,a[i].fi-1); if(l<=a[i].se)U1(al,a[i].se,Q(al,l,a[i].se)*x%MOD); if(a[i].se^n)U2(al,a[i].se+1,x+1); } A[t]=Q(al,n,n); } NTT(A);ll iv=ksm(512),as=0; for(int k=0;k<m;k++)as=(as+(1-A[k]*iv%MOD*C[k])%MOD*I[m-k])%MOD; printf("%lld\n",(as+MOD)*m%MOD); } ```