2huk
2024-11-16 21:59:09
https://atcoder.jp/contests/abc380/tasks/abc380_g
唉。
给定一个长度为 $n$ 的排列 $p$。你需要等概率随机选择一个长度为 $k$ 的区间，并将区间内的元素随机打乱。求打乱后逆序对数的期望（对 $998244353$ 取模）。
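不妨枚举打乱的区间为 $[l, r]$，其中 $r = l + k - 1$，$1 \le l \le n - k + 1$。打乱只改变区间内部元素的位置：两端都在区间外的数对不受影响；一端在区间内、一端在区间外的逆序对总数也不变（区间内的值集合不变）；区间内部的 $\binom{k}{2}$ 个数对各以 $\tfrac12$ 的概率成为逆序对，期望贡献 $\tfrac{k(k-1)}{4}$。因此答案为 $\dfrac{1}{n-k+1}\sum_{l}\bigl(\text{不完全落在 } [l, r] \text{ 内部的逆序对数}\bigr) + \dfrac{k(k-1)}{4}$。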
#include "bits/stdc++.h"
using namespace std;
// Every later `int` token is textually replaced by `long long`, so inversion
// counts (up to about n^2 / 2) and products of two residues mod P fit.
// This is also why the entry point must be spelled `signed main`.
#define int long long
// N: capacity of the per-position arrays (n plus the sentinel slots n+1, n+2);
// P: modulus of the printed answer (prime, so inverses come from fpm(x, P - 2)).
const int N = 2e5 + 10, P = 998244353;
// Binary exponentiation: returns a^b mod P for non-negative a and b.
// main calls it with b = P - 2 to obtain modular inverses (P is prime).
int fpm(int a, int b) {
    int result = 1;
    for (; b != 0; b >>= 1) {
        if (b & 1) result = 1ll * result * a % P;
        a = 1ll * a * a % P;
    }
    return result;
}
// Input size, window length and the permutation a[1..n]; a[0] stays 0 and
// a[n+1], a[n+2] are set to n+1 in main as sentinels.
int n, k;
int a[N];
// root[i]: persistent segment-tree version holding the values at positions
// 0..i (once main has built it). idx: nodes allocated so far; allocation uses
// ++idx, so index 0 is never handed out.
int root[N];
int idx;
// Node pool: left/right child indices and the count v of stored values that
// fall in the node's value range.
struct Node {
    int l, r, v;
};
Node tr[N * 40];
// Allocates a fresh all-zero segment tree over the value range [l, r] and
// returns the index of its root.
int build(int l, int r) {
    const int node = ++ idx;
    if (l == r) return node;
    const int mid = (l + r) >> 1;
    tr[node].l = build(l, mid);
    tr[node].r = build(mid + 1, r);
    return node;
}
void pushup(int u) {
tr[u].v = tr[tr[u].l].v + tr[tr[u].r].v;
}
// Persistent point update: returns the root of a new version that equals
// version u (covering value range [l, r]) except that the count at value x is
// incremented by one. Only nodes on the path to x are copied; every other
// subtree is shared with version u.
int modify(int u, int l, int r, int x) {
    const int fresh = ++ idx;
    tr[fresh] = tr[u];
    if (l == r) {
        ++ tr[fresh].v;
        return fresh;
    }
    const int mid = (l + r) >> 1;
    if (x > mid) tr[fresh].r = modify(tr[fresh].r, mid + 1, r, x);
    else tr[fresh].l = modify(tr[fresh].l, l, mid, x);
    // Recompute this node's count from its (possibly new) children.
    tr[fresh].v = tr[tr[fresh].l].v + tr[tr[fresh].r].v;
    return fresh;
}
int query(int u, int tl, int tr, int l, int r) {
if (tl >= l && tr <= r) return ::tr[u].v;
int mid = tl + tr >> 1, res = 0;
if (l <= mid) res = query(::tr[u].l, tl, mid, l, r);
if (r > mid) res += query(::tr[u].r, mid + 1, tr, l, r);
return res;
}
// Counts positions p in [l, r] whose value a[p] lies in [x, y].
// It takes the version root[r] and subtracts root[l - 1]. An empty position
// range or value range yields 0.
int work(int l, int r, int x, int y) {
    if (r < l || y < x) return 0;
    int count = query(root[r], 0, n + 1, x, y);
    if (l != 0) count -= query(root[l - 1], 0, n + 1, x, y);
    return count;
}
// pre[i] / suf[i]: inversion totals inside the prefix [1, i] / suffix [i, n],
// filled by main. fac is a factorial buffer that no visible code reads.
int pre[N];
int suf[N];
int fac[N];
// Reads n, k and the permutation a[1..n]. Prints the expected number of
// inversions (mod P) after a length-k window, chosen uniformly among the
// n - k + 1 positions, is shuffled uniformly at random.
//
// Answer = (1 / (n-k+1)) * sum over windows [l, r = l+k-1] of the inversions
// that are not entirely inside the window, plus k(k-1)/4 for the shuffled
// window itself (each of its k(k-1)/2 pairs is inverted with probability 1/2).
// For window [l, r] those "outside" inversions split into five parts:
//   pre[l-1] : both ends in the prefix [1, l-1]
//   suf[r+1] : both ends in the suffix [r+1, n]
//   lst1     : prefix x suffix
//   lst2     : prefix x window
//   lst3     : window x suffix
// The three cross terms are updated incrementally as the window slides.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> k;
    // Persistent trees over the value domain [0, n+1]; root[i] holds the
    // values at positions 0..i. Sentinels: position 0 holds value 0 (a[0]
    // stays 0) and positions n+1, n+2 hold value n+1. Their value ranges are
    // empty and their strict comparisons are false, so terms that touch a
    // boundary position contribute nothing.
    // root[n + 2] briefly holds the empty tree before it is overwritten below.
    root[n + 2] = build(0, n + 1);
    root[0] = modify(root[n + 2], 0, n + 1, 0);
    for (int i = 1; i <= n; ++ i ) {
        cin >> a[i];
        root[i] = modify(root[i - 1], 0, n + 1, a[i]);
    }
    a[n + 1] = a[n + 2] = n + 1;
    root[n + 1] = modify(root[n], 0, n + 1, n + 1);
    root[n + 2] = modify(root[n + 1], 0, n + 1, n + 1);

    // pre[i]: inversions with both ends in [1, i].
    for (int i = 1; i <= n; ++ i ) {
        pre[i] = work(0, i - 1, a[i] + 1, n + 1) + pre[i - 1];
    }
    // suf[i]: inversions with both ends in [i, n] (suf[n + 1] stays 0).
    for (int i = n; i; -- i ) {
        suf[i] = work(i + 1, n + 1, 0, a[i] - 1) + suf[i + 1];
    }

    int res = 0;
    long long lst1 = 0, lst2 = 0, lst3 = 0;
    // lst1 = #inversions (i, j) with i in [1, l-1] and j in [l+k, n].
    // Step l-1 -> l: element l+k-1 leaves the suffix and element l-1 joins
    // the prefix.
    for (int l = 1; l + k - 1 <= n; ++ l ) {
        lst1 -= work(0, l - 2, a[l + k - 1] + 1, n + 1);
        lst1 += work(l + k, n, 0, a[l - 1] - 1);
        res = (res + lst1) % P;
    }
    // lst2 = #inversions (i, j) with i in [1, l-1] and j in [l, l+k-1].
    // Step l-1 -> l: element l-1 moves from the window to the prefix and
    // element l+k-1 enters the window. The pair (l-1, l+k-1) is counted by
    // both "+=" terms, so one copy is removed.
    for (int l = 1; l + k - 1 <= n; ++ l ) {
        lst2 -= work(0, l - 2, a[l - 1] + 1, n + 1);
        lst2 += work(0, l - 1, a[l + k - 1] + 1, n + 1);
        lst2 += work(l, l + k - 1, 0, a[l - 1] - 1);
        lst2 -= (a[l - 1] > a[l + k - 1]);
        res = (res + lst2) % P;
    }
    // lst3 = #inversions (i, j) with i in [r-k+1, r] and j in [r+1, n].
    // Step r+1 -> r: element r+1 moves from the window to the suffix and
    // element r-k+1 enters the window. The pair (r-k+1, r+1) is counted by
    // both "+=" terms, so one copy is removed.
    for (int r = n; r - k + 1 >= 1; -- r ) {
        lst3 -= work(r + 2, n, 0, a[r + 1] - 1);
        lst3 += work(r + 1, n, 0, a[r - k + 1] - 1);
        lst3 += work(r - k + 1, r, a[r + 1] + 1, n + 1);
        lst3 -= (a[r - k + 1] > a[r + 1]);
        res = (res + lst3) % P;
    }
    // Inversions entirely inside the prefix or entirely inside the suffix.
    for (int l = 1, r = k; r <= n; ++ l, ++ r )
        res = (res + pre[l - 1] + suf[r + 1]) % P;

    const int inv_windows = fpm(n - k + 1, P - 2);
    const int inside = 1ll * k * (k - 1) % P * fpm(4, P - 2) % P;
    cout << (1ll * res * inv_windows % P + inside) % P << '\n';
    return 0;
}