题解:AT_abc380_g [ABC380G] Another Shuffle Window

qfy123

2024-11-18 08:27:59

Solution

Solution

在做此题之前,先要知道一个结论:

随机在 [1,n-k+1] 中选一个数 i,选到每个 i 的概率为 \frac{1}{n-k+1}。然后,对于选中的区间 [i, i + k - 1],我们先计算这个区间内逆序对的个数 cnt_i,整个序列逆序对个数 cnt_0,以及一个长度为 k 的区间随机打乱后的逆序对个数的期望 E = \frac{k(k-1)}{4},此时对期望的贡献为:

\frac{cnt_0 - cnt_i + E}{n-k+1}

那么答案就是:

\sum_{i=1}^{n-k+1} \frac{cnt_0 - cnt_i + E}{n-k+1} \pmod {998244353}

如何快速求出每个 cnt_i 呢?首先算出 cnt_1 的值,然后,每次将长度为 k 的区间向右滑动的时候,先减去 p_i 的贡献,再加上 p_{i+k} 的贡献,这样就能在 O(\log n) 的复杂度下将该区间向右滑动一次。具体详见代码。

最后,由于模数是质数,根据费马小定理,一个数 xp 意义下的逆元(\frac{1}{x}p 取模的结果)就是 x^{p-2} \pmod p,用快速幂求即可。

Code

#include<bits/stdc++.h>
#define int long long 
#define ull unsigned long long
#define ri register int
#define rep(i,j,k) for(ri i=(j);i<=(k);++i) 
#define per(i,j,k) for(ri i=(j);i>=(k);--i)
#define repl(i,j,k,l) for(ri i=(j);(k);i=(l))
#define IOS ios::sync_with_stdio(false);cin.tie(0);cout.tie(0)
#define pc(x) putchar(x)
#define fir first
#define se second 
#define MP pair<int,int>
#define pii pair<int,int>
#define PB push_back
#define lson p << 1
#define rson p << 1 | 1
#define ls(p) tr[p].ch[0]
#define rs(p) tr[p].ch[1]
using namespace std;
char BUFFER[100000],*P1(0),*P2(0);
#define gtc() (P1 == P2 && (P2 = (P1 = BUFFER) + fread(BUFFER,1,100000,stdin), P1 == P2) ? EOF : *P1++)
inline int R(){
    int x;char c;bool f = 0;
    while((c = gtc()) < '0') if(c == '-') f = 1;
    x = c ^ '0';
    while((c = gtc()) >= '0') x = (x << 3) + (x << 1) + (c ^ '0');
    return f?(~x + 1):x;
}
inline string Rs(){
    string str = "";
    char ch = gtc();
    while(ch == ' ' || ch == '\n' || ch == '\r') ch = gtc();
    while(ch != ' ' && ch != '\n' && ch != '\r' && ch > '\0') str += ch, ch = gtc();
    return str;
}
inline int rS(char s[]){
    int tot = 0; char ch = gtc();
    while(ch == ' ' || ch == '\n' || ch == '\r') ch = gtc();
    while(ch != ' ' && ch != '\n' && ch != '\r' && ch > '\0') s[++tot] = ch, ch = gtc();
    return tot; 
}
inline void O(int x){
    if(x < 0) pc('-'),x = -x;
    if(x < 10) pc(x + '0');
    else O(x / 10),pc(x % 10 + '0');
}
inline void out(int x,int type){
    if(type == 1) O(x),pc(' ');
    if(type == 2) O(x),pc('\n');
    if(type == 3) O(x);
}
inline void Ps(string s, int type){
    int m = s.length();
    rep(i, 0, m - 1) pc(s[i]); 
    if(type == 1) pc(' ');
    if(type == 2) pc('\n');
}
inline void pS(char *s, int type){
    int m = strlen(s + 1);
    rep(i, 1, m) pc(s[i]);
    if(type == 1) pc(' ');
    if(type == 2) pc('\n');
}
inline void OI(){
    freopen(".in","r",stdin);
    freopen(".out","w",stdout);
}
const int N = 2e5 + 10;
const int mod = 998244353;
int tr[N], n, k, p[N], ans, tot, cnt;
inline int lowbit(int x){
    return x & (-x);
}
inline void mdf(int x, int v){
    while(x <= N - 10){
        tr[x] += v;
        x += lowbit(x);
    }
}
inline int sum(int x){
    int ret = 0;
    while(x){
        ret += tr[x];
        x -= lowbit(x);
    }
    return ret;
}
inline int qry(int l, int r){return sum(r) - sum(l - 1);}
inline int ksm(int a, int b){
    int ret = 1;
    while(b){
        if(b & 1) ret = (ret * a) % mod;
        a = (a * a) % mod;
        b >>= 1;
    }
    return ret;
}
inline void solve(){
    n = R(), k = R();
    rep(i, 1, n) p[i] = R();
    rep(i, 1, n){
        tot += qry(p[i] + 1, n);
        mdf(p[i], 1);
        if(i == k) cnt = tot; 
    }
    per(i, n, k + 1) mdf(p[i], -1);//初始时,将 p[1...k] 加入树状数组中,cnt 记录的是 [1, k] 的逆序对数
    int E = k * (k - 1) % mod * ksm(4, mod - 2) % mod;
    rep(i, 1, n - k + 1){
        ans = (ans + (tot - cnt + E + mod) % mod * ksm(n - k + 1, mod - 2) % mod) % mod;
        if(i == n - k + 1) return out(ans, 2), void();
        cnt -= sum(p[i] - 1); mdf(p[i], -1); //减去 p[i] 的贡献
        cnt += qry(p[i + k] + 1, n); mdf(p[i + k], 1); //加上 p[i + k] 的贡献
        //在上面两次操作后,cnt 表示的区间滑动到了 [i + 1, i + k]
    }
}
signed main(){
    // OI();
    int T = 1;
    // T = R();
    while(T--) solve();
    return 0;
}