题解 P5367 【【模板】康托展开】

yangrunze

2020-02-26 16:22:47

Solution

大水题

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
int a[1000005],n;
int num[1000005];
int main(){
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        scanf("%d",&a[i]);
        num[i]=i;
    }
    int ans=1;
    do{
        bool flag=1; 
        for(int i=1;i<=n;i++){//判断相等 
            if(a[i]!=num[i]){
                flag=0;
                break;
            }
        }
        if(flag){
            printf("%d",ans);//相等就输出 
            return 0;
        }
        ans++;
    }while(next_permutation(num+1,num+1+n));//暴力枚举全排列 
    return 0;
}

这人疯了不要搭理他

咳咳,以上这种粗鲁蛮横的做法肯定是AC不了的,众所周知,这题我们要用一个神奇的方法——康托展开

康托展开讲的是什么故事呢?是这样的:对于一个1到n的排列\{a_1,a_2,\cdots,a_n\},比它小的排列有这些个:

\sum^n_{i=1}sum_{a_i}\times (n-i)!

sum_{a_i},就是在a_i后面的元素里比它小的元素个数,即\sum_{j=i}^n(a_j<a_i)

是不是大家开始一脸懵了呢?没关系,这是正常现象,因为上面那两句压根不是给人看的,咱们还是从一个具体的栗子讲起:

2,4,1,5,3

看这个排列,怎样求他是第几小的呢?别急,咱们一位一位去考虑

经过这么一通解释,现在大家应该都理解了康托展开的原理,上面的鬼畜柿子也应该懂了吧!那接下来就应该考虑代码实现的问题啦!

首先,阶乘这一块是珂以直接预处理出来的:

    for(int i=1;i<=n;i++)
    jc[i]=(jc[i-1]*i)%998244353;//阶乘的定义

然后,根据定义算出排名:

    int ans=0;
    for(int i=1;i<=n;i++){
      int sum=0;
      for(int j=i;j<=n;j++)
        if(a[i]<a[j])sum++;//统计sum
     ans=(ans+sum*jc[n-i])%998244353;//计算ans                       
    }
    printf("%d",ans+1);//别忘了+1

这么写的,你是不是又想TLE了???

好吧,以上代码没问题,但是这个题的数据很友(du)善(liu),O(n^2)的复杂度是跑不过的......那要多少呢?O(\log n)差不多,带\log的就那几样数据结构,经过精挑细选,我们就用常数小,操作方便,代码又好写的树状数组吧!

能优化的,无非就是 统计sum 的那一步......等会,树状数组是干啥用的来着?

单点修改,区间查询

那我们就要把sum的计算转化成区间和,这里有一个方法:一开始每个数都是1,如果某个数出现过了,就把它变成0,而最后要求的sum_{a_i},就是1a_i-1的区间和!(为什么?因为比它小的数都在它前面嘛)

咱们还是以刚才2,4,1,5,3的那个栗子来解释:

最后,就是写代码的时间啦!

#include<iostream>
#include<cstdio>
using namespace std;
typedef long long ll;//虽然不知道这题不开long long会不会见祖宗,但保险起见还是开一下吧
ll tree[1000005];//树状数组
int n;
  //树状数组的“老三件”:lowbit,修改,求和
int lowbit(int x){
    return x&-x;
}
void update(int x,int y){
    while(x<=n){
        tree[x]+=y;
        x+=lowbit(x);   
    }
}
ll query(int x){
    ll sum=0;
    while(x){
        sum+=tree[x];
        x-=lowbit(x);   
    }
    return sum;
}
const ll wyx=998244353;//懒人专用
ll jc[1000005]={1,1};//存阶乘的数组
int a[1000005];//存数的
int main(){
    scanf("%d",&n);
    for(int i=1;i<=n;i++){//预处理阶乘数组和树状数组
        jc[i]=(jc[i-1]*i)%wyx;
        update(i,1);
    }
    ll ans=0;
    for(int i=1;i<=n;i++){
        scanf("%d",&a[i]);
        ans=(ans+((query(a[i])-1)*jc[n-i])%wyx)%wyx;//计算ans
        update(a[i],-1);//把a[i]变成0(原来是1,减1不就是0嘛)
    }
    printf("%lld",ans+1);//别忘了+1
   return /*2333333333*/ 0;
}

总结一下: