宋小陀螺 @ 2024-08-02 08:02:17
#include<bits/stdc++.h>
using namespace std;
// State shared with main(): n = element count, a[1..n] = input values
// (1-indexed), tmp = merge scratch buffer, ans = running inversion count.
// NOTE(review): capacity 500005 assumes n <= 500004 (problem limit) — confirm.
long long n,a[500005],tmp[500005],ans;

// Merge-sorts a[lft..rgt] into ascending order and adds to `ans` the number
// of inversions inside that range: pairs (p, q) with lft <= p < q <= rgt and
// a[p] > a[q]. Equal values are not counted as inversions.
void msrt(long long lft,long long rgt)
{
    if (rgt <= lft) return;  // 0 or 1 element: sorted, no inversions
    long long mid = (lft + rgt) >> 1;
    msrt(lft,mid);
    msrt(mid + 1,rgt);
    long long i = lft,j = mid + 1,k = lft;
    while (i <= mid && j <= rgt)
    {
        // Prefer the left element on ties so equal pairs are not counted.
        if (a[i] <= a[j]) tmp[k++] = a[i++];
        else
        {
            // a[j] is smaller than every remaining left element a[i..mid]
            // (the left half is sorted), so each of them forms an inversion.
            ans += mid - i + 1;
            tmp[k++] = a[j++];
        }
    }
    // Draining adds no inversions: leftover left elements were already
    // counted against every right element smaller than them, and leftover
    // right elements are >= every left element.
    while (i <= mid) tmp[k++] = a[i++];
    while (j <= rgt) tmp[k++] = a[j++];
    for (long long p = lft;p <= rgt;p++) a[p] = tmp[p];
}
// Reads n and then n integers into a[1..n], runs msrt over the whole range,
// and prints the count accumulated in ans.
int main()
{
    // Up to ~5e5 numbers are read: decouple iostreams from C stdio and untie
    // cin from cout so input is not slowed by synchronization/flushes.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin >> n;
    // NOTE(review): no bound check — assumes n fits in a[] (n <= 500004).
    for (long long i = 1;i <= n;i++) std::cin >> a[i];
    msrt(1,n);
    std::cout << ans << '\n';  // '\n' instead of endl: stream flushes at exit
    return 0;
}
by LiujunjiaNC @ 2024-08-02 08:10:31
@宋小陀螺 修改后的代码如下（已 AC）：
#include<bits/stdc++.h>
using namespace std;
// State shared with main(): n = element count, a[1..n] = input values
// (1-indexed), tmp = merge scratch buffer, ans = running inversion count.
long long n,a[500005],tmp[500005],ans;

// Merge-sorts a[lft..rgt] into DESCENDING order and adds to `ans` the number
// of pairs (p, q) with lft <= p < q <= rgt and a[p] > a[q].
void msrt(long long lft,long long rgt)
{
    if (lft >= rgt) return;  // ranges of size 0 or 1 contribute nothing
    const long long half = (lft + rgt) >> 1;
    msrt(lft, half);
    msrt(half + 1, rgt);

    long long left = lft, right = half + 1, out = lft;
    for (; left <= half && right <= rgt; ++out)
    {
        if (a[left] <= a[right])
        {
            tmp[out] = a[right];
            ++right;
            continue;
        }
        // Both halves are descending, so a[left] is larger than a[right] and
        // every element after it in the right half: rgt - right + 1 pairs.
        ans += rgt - right + 1;
        tmp[out] = a[left];
        ++left;
    }

    // Copy whichever half is left over; no further pairs arise here.
    for (; left <= half; ++left, ++out) tmp[out] = a[left];
    for (; right <= rgt; ++right, ++out) tmp[out] = a[right];

    for (long long pos = lft; pos <= rgt; ++pos) a[pos] = tmp[pos];
}
// Reads n and then n integers into a[1..n], runs msrt over the whole range,
// and prints the count accumulated in ans.
int main()
{
    // Up to ~5e5 numbers are read: decouple iostreams from C stdio and untie
    // cin from cout so input is not slowed by synchronization/flushes.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin >> n;
    // NOTE(review): no bound check — assumes n fits in a[] (n <= 500004).
    for (long long i = 1;i <= n;i++) std::cin >> a[i];
    msrt(1,n);
    std::cout << ans << '\n';  // '\n' instead of endl: stream flushes at exit
    return 0;
}
by 宋小陀螺 @ 2024-08-02 08:16:39
@LiujunjiaNC 谢谢，已 AC