NTT求调

P4245 【模板】任意模数多项式乘法

可爱的小棉羊 @ 2024-01-29 16:29:41

#include<bits/stdc++.h>
using namespace std;

const int mod1=998244353,mod2=1004535809,mod3=469762049,g=3;
long long fpow(long long a,long long b,long long mod){
    long long ans=1;
    while(b){
        if(b&1)ans=ans*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return ans;
}
int res[140006],s,k,p;
long long a[140006],b[140006],am1[140005],am2[140005],am3[140005];
long long m,bm1[140006],bm2[140005],bm3[140005],n;
void ntt(long long *a,int n,int mod,bool fl){
    for(int i=0;i<n;i++)if(i<res[i])swap(a[i],a[res[i]]);

    for(int h=1;h<n;h<<=1){
        long long wn=fpow((fl? g:fpow(g,mod-2,mod)),(mod-1)/(h<<1),mod);
        for(int j=0;j<n;j+=h<<1){
            long long w=1;
            for(int k=0;k<h;k++,w=w*wn%mod){
                long long x=a[j+k],y=w*a[j+h+k]%mod;
                a[j+k]=(x+y)%mod;
                a[j+k+h]=(x-y+mod)%mod;
            }
        }
    }
    if(fl==0){
        long long inv=fpow(n,mod-2,mod);
        for(int i=0;i<n;i++)a[i]=a[i]*inv%mod;
    }
}
long long get(int i){
    int a=am1[i],b=am2[i],c=am3[i];
    __int128 k1=((b-a)%mod2+mod2);
    k1=k1*fpow(mod1,mod2-2,mod2);
    __int128 x=a+k1*mod1;
    __int128 k4=((c-x)%mod3+mod3)%mod3;
    k4=k4*fpow(mod1*mod2%mod3,mod3-2,mod3)%mod3;
    x=(x+k4*mod1*mod2)%((__int128)mod1*mod2*mod3);
    return (x%p+p)%p;
}
int main(){
    cin>>n>>m>>p;
    for(int i=0;i<n;i++){
        cin>>a[i];
        am1[i]=am2[i]=am3[i]=a[i];
    }
    for(int i=0;i<m;i++){
        cin>>b[i];
        bm1[i]=bm2[i]=bm3[i]=b[i];
    }
    s=1;
    while(s<=n+m){
        s<<=1;
        k++;
    }
    for(int i=0;i<s;i++)res[i]=(res[i>>1]>>1)|(1<<(k-1));
    ntt(am1,s,mod1,1);
    ntt(am2,s,mod2,1);
    ntt(am3,s,mod3,1);
    ntt(bm1,s,mod1,1);
    ntt(bm2,s,mod2,1);
    ntt(bm3,s,mod3,1);
    for(int i=0;i<s;i++){
        am1[i]=am1[i]*bm1[i]%mod1;
        am2[i]=am2[i]*bm2[i]%mod2;
        am3[i]=am3[i]*bm3[i]%mod3;
    }
    ntt(am1,s,mod1,0);
    ntt(am2,s,mod2,0);
    ntt(am3,s,mod3,0);
    for(int i=0;i<=n+m;i++)cout<<am1[i]<<" ";
} 

|