自己造的数据全对,然后0分

P4245 【模板】任意模数多项式乘法

1lgorithm @ 2021-05-23 18:16:37

#include<iostream>
#include<cmath>
using namespace std;
typedef long long ll;
ll a[4000001],b[4000001];
int l,r[4000001],limit=1;
ll c1[4000001];
ll c2[4000001];
ll c3[4000001];
ll c[4000001];
const int mod1=469762049,g1=3;
const int mod2=998244353,g2=3;
const int mod3=1004535809,g3=3;
ll add(ll a,ll b,ll mod){return (a+b)%mod;}
ll sub(ll a,ll b,ll mod){return (a-b+mod)%mod;}
ll qpmod(ll a,ll b,ll p){
    ll ans=1%p;
    while(b){
        if(b&1) ans=ans*a%p;
        a=a*a%p;
        b>>=1;
    }
    return ans;
}
ll mul(ll a,ll b,ll mod){
    ll ans=0;
    while(b){
        if(b&1) ans=(ans+a)%mod;
        a=a*2%mod;
        b>>=1;
    }
    return ans;
}
ll inv(ll a,ll mod){
    return qpmod(a,mod-2,mod);
}
void NTT(long long *a,int type,int mod,int G){
    for(int i=0;i<limit;i++){
        if(i<r[i]) swap(a[i],a[r[i]]);
    }
    for(int mid=1;mid<limit;mid<<=1){
        ll wn=qpmod(G,(mod-1)/(mid<<1),mod);
        if(type==-1) wn=qpmod(wn,mod-2,mod);
        for(int R=mid<<1,j=0;j<limit;j+=R){
            ll w=1;
            for(int k=0;k<mid;++k,w=w*wn%mod){
                int x=a[j+k],y=w*a[j+mid+k]%mod;
                a[j+k]=add(x,y,mod);
                a[j+mid+k]=sub(x,y,mod);
            }
        }
    }
    if(type==1) return ;
    ll inv=qpmod(limit,mod-2,mod);
    for(int i=0;i<limit;++i) a[i]=a[i]*inv%mod;
}
void multi(int n,ll *a,int m,ll *b,ll *c,ll mod,ll G){
    l=0,limit=1;
    while(limit<=m+n) limit<<=1,++l;
    for(int i=0;i<limit;++i){
        r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    }
    NTT(a,1,mod,G);
    NTT(b,1,mod,G);
    for(int i=0;i<=limit;i++) c[i]=a[i]*b[i]%mod;
    NTT(a,-1,mod,G);
    NTT(b,-1,mod,G);
    NTT(c,-1,mod,G);
}
ll mul3(ll a,ll b,ll c,ll mod){
    return mul(mul(a,b,mod),c,mod);
}
ll exgcd(ll a,ll b,ll &x,ll &y){
    if(b==0){
        x=1,y=0;
        return a;
    }
    ll g=exgcd(b,a%b,x,y),t=x;
    x=y,y=t-a/b*x;
    return g;
}
int main(){
    int n,m,mod;
    cin>>n>>m>>mod;
    for(int i=0;i<=n;i++) cin>>a[i];
    for(int i=0;i<=m;i++) cin>>b[i];
    multi(n,a,m,b,c1,mod1,g1);
    multi(n,a,m,b,c2,mod2,g2);
    multi(n,a,m,b,c3,mod3,g3);
    ll invx,t,A,K,M=mod1*mod2;
    for(int i=0;i<=n+m;++i){
        invx=inv(mod1,mod2);
        t=invx*(c2[i]-c1[i]);
        A=c1[i]+t*mod1;
        K=mul(c3[i]-A,inv(M,mod3),mod);
        c[i]=mul(K,M,mod)+A;
        cout<<c[i]%mod<<' ';
    }
}

|