题解:P11278 绝世丑角

bcdmwSjy

2024-11-15 07:44:07

Solution

神仙题,根本不会 Nim 积的人(包括我)也就只能得 13 分了。

对于一些二维 Nim 游戏,可以拆分成两维单独的 Nim 游戏然后求 Nim 积,ab 的 Nim 积定义为

x\otimes y=\operatorname{mex}\{(a\otimes b)\oplus(a\otimes y)\oplus(x\otimes b),0\le a<x,0\le b<y\}

其中 \otimes 为 Nim 积,\oplus 为异或。

在数学家们的不断努力下 (其实是我不会证),发现 Nim 积有如下性质:

$(x\otimes y)\otimes z=x\otimes(y\otimes z)$,具有结合律。 定义 $\operatorname{Fermat\ 2-power}$ 为 $x=2^{2^n},n\in\mathbb{N}$。 1. 一个 $\operatorname{Fermat\ 2-power}$ 与小于自己的数的 Nim 积为正常的乘积,$x\otimes y=xy(y<x)$。 2. 一个 $\operatorname{Fermat\ 2-power}$ 与自己的 Nim 积为 $\frac{3}{2}x$。 现在的问题是如何求两个数的 Nim 积,暴力求的复杂度是 $O(x^2y^2)$ 的,但是可以利用一些性质在 $O(\log x\log y)$ 的时间复杂度内求出。 设 $f(x,y)=x\otimes y$,特判边界后,可以拆位计算每个二进制的贡献。 设 $g(x,y)=2^x\otimes2^y$,那么 $f(x,y)$ 就等于 $x$ 和 $y$ 每一位 $g$ 的异或和。 接下来考虑如何求 $g(x,y)$,仍然是拆位,拆成一些 $\operatorname{Fermat\ 2-power}$,那么 $$g(x,y)=2^x\otimes2^y=\left(\bigotimes\limits_{x'\in x}2^{2^{x'}}\right)\otimes\left(\bigotimes\limits_{y'\in y}2^{2^{y'}}\right)$$ 从高到低考虑每一位,如果 $x$ 和 $y$ 在这一位上都为 $0$ 就可以跳过。 先处理 $x$ 和 $y$ 在这一位上只有一个为 $1$ 的情况: $\operatorname{Fermat\ 2-power}$ 和比自己小的数的 Nim 积就是乘积,从高到低计算即可。 再看 $x$ 和 $y$ 在这一位上都为 $1$ 的情况: 根据结合律,可以把这一位单独拿出来算,再根据 $\operatorname{Fermat\ 2-power}$ 的特点,可以得出 $\left(2^{2^{x_u'}}\right)\otimes\left(2^{2^{y_u'}}\right)=\frac{3}{2}\times2^{2^u}=3\times2^{2^u-1}

总结一下上面的情况,可以得出

\\ &=\left(\prod\limits_{i\in x\operatorname{xor}y}2^{2^i}\right)\otimes\left(\bigotimes\limits_{i\in x\operatorname{and}y}3\times2^{2^i-1}\right)\end{aligned}

前面的直接算,后面的用 f 递归计算即可。

每次只会遍历 xy 的二进制位,复杂度为 O(\log x\log y),记得给 g 记忆化一下,这样复杂度才对。

于是我们可以得到以下代码:

ll mem[128][128];

ll f(ll,ll);

ll g(int x,int y) {
    if (x==0 or y==0) return 1ll<<(x|y);
    if (mem[x][y]!=-1) return mem[x][y]; 
    ll res=1;
    for (int u=(x^y),i=0;(1<<i)<=u;i++){
        if ((u>>i)&1){
            res<<=(1ll<<i);
        }
    }
    for (int u=(x&y),i=0;(1<<i)<=u;i++){
        if ((u>>i)&1){
            res=f(res,3ll<<((1<<i)-1));
        }
    }
    return mem[x][y]=res;
}

ll f(ll x,ll y) {
    if (x==0 or y==0) return x|y;
    if (x==1 or y==1) return max(x,y);
    ll res=0;
    for (int i=0;(1ll<<i)<=x;i++){
        if ((x>>i)&1){
            for (int j=0;(1ll<<j)<=y;j++){
                if ((y>>j)&1){
                    res^=g(i,j);
                }
            }
        }
    }
    return res;
}

那这些有和这道题有什么关系呢,又经过数学家的不懈努力, 我们得到一个数和它自己的 Nim 积是有循环节的,并且长度是 O(\log x) 的,在本题中,循环节最长就是 32 的。

于是,我们可以先预处理出每个数的循环节,再用线段树维护加和还有异或和,修改时把数据旋转一位即可。

但是,这真的做完了吗,仔细分析一下,共有 n 个数,每个数的循环节是 \log w 的,求 Nim 积是时间复杂度是 O(\log^2 w) 的,光建树就是 O(n\log^3 w) 的,根本过不去。

但是我们发现 Nim 积的高位和低位可以分别计算,于是我们把它分成前 16 位和后 16 位预处理出自己和自己的 Nim 积和 Nim 积的 k 次方。

ll prod_high[65536],prod_low[65536],pw_high[65536][32],pw_low[65536][32];

ll prod(ll x){
    return prod_high[x>>16]^prod_low[x&65535];
}

ll pw(ll x,int k){
    return pw_high[x>>16][k]^pw_low[x&65535][k];
}

void init(){
    for (ll i=0;i<65536;i++){
        for (int j=0;j<32;j++){
            if (i&(i-1)){
                prod_high[i]=prod_high[i&(i-1)]^prod_high[i&(-i)];
                prod_low[i]=prod_low[i&(i-1)]^prod_low[i&(-i)];
            }else{
                prod_high[i]=f(i<<16,i<<16);
                prod_low[i]=f(i,i);
            }
        }
    }
    for (ll i=0;i<65536;i++){
        pw_high[i][0]=i<<16;
        pw_low[i][0]=i;
    }
    for (int i=0;i<65536;i++){
        for (int j=1;j<32;j++){
            pw_high[i][j]=prod(pw_high[i][j-1]);
            pw_low[i][j]=prod(pw_low[i][j-1]);
        }
    }
}

其中 prod 用来计算自己和自己的 Nim 积,pw 计算自己 Nim 积的 k 次方,这样我们就可以快速计算 Nim 积了。

完整代码如下:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

ll mem[128][128];

ll f(ll,ll);

ll g(int x,int y) {
    if (x==0 or y==0) return 1ll<<(x|y);
    if (mem[x][y]!=-1) return mem[x][y]; 
    ll res=1;
    for (int u=(x^y),i=0;(1<<i)<=u;i++){
        if ((u>>i)&1){
            res<<=(1ll<<i);
        }
    }
    for (int u=(x&y),i=0;(1<<i)<=u;i++){
        if ((u>>i)&1){
            res=f(res,3ll<<((1<<i)-1));
        }
    }
    return mem[x][y]=res;
}

ll f(ll x,ll y) {
    if (x==0 or y==0) return x|y;
    if (x==1 or y==1) return max(x,y);
    ll res=0;
    for (int i=0;(1ll<<i)<=x;i++){
        if ((x>>i)&1){
            for (int j=0;(1ll<<j)<=y;j++){
                if ((y>>j)&1){
                    res^=g(i,j);
                }
            }
        }
    }
    return res;
}

ll prod_high[65536],prod_low[65536],pw_high[65536][32],pw_low[65536][32];

ll prod(ll x){
    return prod_high[x>>16]^prod_low[x&65535];
}

ll pw(ll x,int k){
    return pw_high[x>>16][k]^pw_low[x&65535][k];
}

void init(){
    for (ll i=0;i<65536;i++){
        for (int j=0;j<32;j++){
            if (i&(i-1)){
                prod_high[i]=prod_high[i&(i-1)]^prod_high[i&(-i)];
                prod_low[i]=prod_low[i&(i-1)]^prod_low[i&(-i)];
            }else{
                prod_high[i]=f(i<<16,i<<16);
                prod_low[i]=f(i,i);
            }
        }
    }
    for (ll i=0;i<65536;i++){
        pw_high[i][0]=i<<16;
        pw_low[i][0]=i;
    }
    for (int i=0;i<65536;i++){
        for (int j=1;j<32;j++){
            pw_high[i][j]=prod(pw_high[i][j-1]);
            pw_low[i][j]=prod(pw_low[i][j-1]);
        }
    }
}

int n,q;
ll a[250001];

struct Node{
    ll x[32],s[32];
    Node(){
        memset(x,0,sizeof(x));
        memset(s,0,sizeof(s));
    }
    inline void calc(int t){
        rotate(x,x+t,x+32);
        rotate(s,s+t,s+32);
    }
};

struct Tree{
    int l,r,tag;
    Node val;
};

Tree tr[524289];

#define ls (i<<1)
#define rs (i<<1|1)

void pushup(int i){
    for (int j=0;j<32;j++){
        tr[i].val.x[j]=tr[ls].val.x[j]^tr[rs].val.x[j];
        tr[i].val.s[j]=tr[ls].val.s[j]+tr[rs].val.s[j];
    }
}

void pushdown(int i){
    if (tr[i].tag){
        tr[ls].val.calc(tr[i].tag);
        tr[rs].val.calc(tr[i].tag);
        tr[ls].tag=(tr[ls].tag+tr[i].tag)&31;
        tr[rs].tag=(tr[rs].tag+tr[i].tag)&31;
        tr[i].tag=0;
    }
}

void build(int i,int l,int r){
    tr[i].l=l;
    tr[i].r=r;
    if (l==r){
        for (int j=0;j<32;j++){
            tr[i].val.s[j]=tr[i].val.x[j]=pw(a[l],j);
        }
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    pushup(i);
}

void update(int i,int l,int r){
    if (tr[i].l>=l and tr[i].r<=r){
        tr[i].val.calc(1);
        tr[i].tag=(tr[i].tag+1)&31;
        return;
    }
    pushdown(i);
    if (tr[ls].r>=l) update(ls,l,r);
    if (tr[rs].l<=r) update(rs,l,r);
    pushup(i);
}

ll querys(int i,int l,int r){
    if (tr[i].l>=l and tr[i].r<=r) return tr[i].val.s[0];
    pushdown(i);
    ll ans=0;
    if (tr[ls].r>=l) ans+=querys(ls,l,r);
    if (tr[rs].l<=r) ans+=querys(rs,l,r);
    return ans;
}

ll queryx(int i,int l,int r){
    if (tr[i].l>=l and tr[i].r<=r) return tr[i].val.x[0];
    pushdown(i);
    ll ans=0;
    if (tr[ls].r>=l) ans^=queryx(ls,l,r);
    if (tr[rs].l<=r) ans^=queryx(rs,l,r);
    return ans;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    memset(mem,-1,sizeof(mem));
    init(); 
    cin>>n>>q;
    for (int i=1;i<=n;i++){
        cin>>a[i];
    }
    build(1,1,n);
    while (q--){
        int op,l,r;
        cin>>op>>l>>r;
        if (op==1){
            update(1,l,r);
        }else if (op==2){
            cout<<queryx(1,l,r)<<"\n";
        }else if (op==3){
            cout<<querys(1,l,r)<<"\n";
        }
    }
    return 0;
}

闲话:赛时写 13 分暴力时错了好多发,原因是线段树传参没开 long long,然后赛后写正解时用了之前的代码,导致线段树区间修改一直递归到叶子节点然后 T 到 13 分,卡了好久常才意识到我自己写错了,以后记得注意一下。