这踏马放 G 纯沙壁。

2huk

2024-11-16 21:59:09

Solution

https://atcoder.jp/contests/abc380/tasks/abc380_g

唉。

给定 n 的排列 p。你需要随机选择一个长度为 k 的区间,并将其随机打乱。求期望逆序对。

不妨枚举打乱的区间为 [l, l+k-1]。此时一个逆序对 i < j 会有 6 种情况。

  1. 考虑 $[l-1,l+k-2]$ 到 $[l, l+k-1]$ 的答案的增量。不难发现: - 满足 $i \le l - 2,j=l+k-1$ 的逆序对取消了贡献。 - 满足 $i=l-1,j \ge l+k$ 的逆序对加入了贡献。 求这两种逆序对的数量可以可持久化线段树。
  2. - 满足 $i \le l-2,j = l-1$ 的逆序对取消了贡献; - 满足 $i=l-1,l \le j \le l+k-1$ 的逆序对加入了贡献。 可持久化线段树。
  3. - 满足 $l-1 \le i \le l+k-2,j=l+k-1$ 的逆序对取消了贡献; - 满足 $i = l +k-1,j \ge l+k$ 的逆序对加入了贡献。
#include "bits/stdc++.h"

using namespace std;

#define int long long

const int N = 2e5 + 10, P = 998244353;

int fpm(int a, int b) {
  int res = 1;
  while (b) {
    if (b & 1) res = 1ll * res * a % P;
    b >>= 1, a = 1ll * a * a % P;
  }
  return res;
}

int n, k, a[N];

int root[N], idx;

struct Node {
  int l, r, v;
}tr[N * 40];

int build(int l, int r) {
  int u = ++ idx;
  if (l != r) {
    int mid = l + r >> 1;
    tr[u].l = build(l, mid);
    tr[u].r = build(mid + 1, r);
  }
  return u;
}

void pushup(int u) {
  tr[u].v = tr[tr[u].l].v + tr[tr[u].r].v;
}

int modify(int u, int l, int r, int x) {
  int v = ++ idx;
  tr[v] = tr[u];
  if (l == r) tr[v].v ++ ;
  else {
    int mid = l + r >> 1;
    if (x <= mid) tr[v].l = modify(tr[v].l, l, mid, x);
    else tr[v].r = modify(tr[v].r, mid + 1, r, x);
    pushup(v);
  }
  return v;
}

int query(int u, int tl, int tr, int l, int r) {
  if (tl >= l && tr <= r) return ::tr[u].v;
  int mid = tl + tr >> 1, res = 0;
  if (l <= mid) res = query(::tr[u].l, tl, mid, l, r);
  if (r > mid) res += query(::tr[u].r, mid + 1, tr, l, r);
  return res;
}

int work(int l, int r, int x, int y) {
  if (l > r || x > y) return 0;
  return query(root[r], 0, n + 1, x, y) - (l ? query(root[l - 1], 0, n + 1, x, y) : 0);
}

int fac[N];
int pre[N], suf[N];

signed main() {
  fac[0] = 1;
  for (int i = 1; i < N; ++ i ) fac[i] = 1ll * fac[i - 1] * i % P;

  cin >> n >> k;

  root[n + 2] = build(0, n + 1);
  root[0] = modify(root[n + 2], 0, n + 1, 0);
  for (int i = 1; i <= n; ++ i ) {
    cin >> a[i];
    root[i] = modify(root[i - 1], 0, n + 1, a[i]);
  }
  a[n + 1] = a[n + 2] = n + 1;
  root[n + 1] = modify(root[n], 0, n + 1, n + 1);
  root[n + 2] = modify(root[n + 1], 0, n + 1, n + 1);

    for (int i = 1; i <= n; ++ i ) {
        pre[i] = work(0, i - 1, a[i] + 1, n + 1) + pre[i - 1];
    }
    for (int i = n; i; -- i ) {
        suf[i] = work(i + 1, n + 1, 0, a[i] - 1) + suf[i + 1];
    }

  int res = 0;
  long long lst1 = 0, lst2 = 0, lst3 = 0;

  for (int l = 1; l + k - 1 <= n; ++ l ) {
    lst1 = (lst1 - work(0, l - 2, a[l + k - 1] + 1, n + 1));
    lst1 = (lst1 + work(l + k, n, 0, a[l - 1] - 1));

    res = (res + lst1) % P;
  }

  for (int l = 1; l + k - 1 <= n; ++ l ) {
    lst2 = (lst2 - work(0, l - 2, a[l - 1] + 1, n + 1));
    lst2 = (lst2 + work(0, l - 1, a[l + k - 1] + 1, n + 1));
    lst2 = (lst2 + work(l, l + k - 1, 0, a[l - 1] - 1));
    lst2 = (lst2 - (a[l - 1] > a[l + k - 1]));

    res = (res + lst2) % P;
  }

  for (int r = n; r - k + 1 >= 1; -- r ) {
    lst3 = (lst3 - work(r + 2, n, 0, a[r + 1] - 1));
    lst3 = (lst3 + work(r + 1, n, 0, a[r - k + 1] - 1));
    lst3 = (lst3 + work(r - k + 1, r, a[r + 1] + 1, n + 1));
    lst3 = (lst3 - (a[r - k + 1] > a[r + 1]));
    res = (res + lst3) % P;
  }

  for (int l = 1, r = k; r <= n; ++ l, ++ r )
    res = (res + pre[l - 1] + suf[r + 1]) % P;

  cout << (1ll * res * fpm(n - k + 1, P - 2) % P + 1ll * k * (k - 1) % P * fpm(4, P - 2) % P) % P;

  return 0;
}