萌新刚学 OI，求助：WA 20pts

P4245 【模板】任意模数多项式乘法

Autofreeze @ 2021-02-21 11:50:24

RT，用三个模数各跑 5 次 NTT（共 15 次），然后用 CRT 合并，目测好像是合并的时候错了。

屑代码

#include<bits/stdc++.h>
#define N 2001001
// `register` is no longer a storage-class specifier in C++17, so `re` now
// expands to nothing; it is kept only so existing code written with `re`
// still compiles unchanged.
#define re
#define MAX 2001
#define eps 1e-10
using namespace std;
typedef long long ll;
typedef double db;
// Three NTT-friendly primes, each of the form c*2^k+1 with primitive root g=3:
//   mod1 = 119*2^23+1, mod2 = 479*2^21+1, mod3 = 7*2^26+1.
// Each supports transforms of length up to at least 2^21. Their product (~4.7e26)
// exceeds the largest exact product coefficient, so CRT over them recovers it.
// That bound is (n+1)*1e9*1e9 ~ 1e23, assuming the problem's limits of
// n, m <= 1e5 and coefficients <= 1e9, which are not checked here.
// NOTE: mod3 must be 469762049. A value such as 409762049 has only 2^8 | p-1,
// so every transform longer than 256 silently produces wrong values.
const ll mod1=998244353,mod2=1004535809,mod3=469762049,g=3;
// Fast reader: parses the next integer from stdin into ret.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. Reading stops at (and consumes) the first
// non-digit after the number. If end of input is reached before any digit,
// ret is left as 0 instead of spinning forever on EOF.
// `long long&` is the same type as the file's `ll&`.
inline void read(long long &ret)
{
    ret=0;
    bool negative=false;
    // Keep getchar()'s result as int. Storing it in a char conflates EOF with
    // byte 0xFF, and passing a negative char (other than EOF) to isdigit is
    // undefined behaviour.
    int c=std::getchar();
    while(c!=EOF&&!std::isdigit(c))
    {
        if(c=='-')
            negative=true;
        c=std::getchar();
    }
    while(c!=EOF&&std::isdigit(c))
    {
        ret=ret*10+(c-'0');
        c=std::getchar();
    }
    if(negative)
        ret=-ret;
}
// Degrees of the two input polynomials and the output modulus.
ll n, m, p;
// Input coefficients: a[0..n] and b[0..m] are read; the rest stay zero (global
// zero-initialisation), which provides the padding up to the transform length.
ll a[N], b[N];
// Product coefficients modulo mod1 / mod2 / mod3, as left by solve1/2/3.
ll ans1[N], ans2[N], ans3[N];
// Transform length: the smallest power of two >= n+m+1 (set in main).
ll num;
// Bit-reversal permutation table, rebuilt by every ntt call.
ll rev[N];
// Inverse of the generator g modulo mod1 / mod2 / mod3 (used by inverse transforms).
ll inv1g, inv2g, inv3g;
// Modular exponentiation by repeated squaring: returns a^b mod p, in [0, p).
//   a : base; any value, including negative or >= p (it is reduced first)
//   b : exponent; must be >= 0 (a negative exponent is treated as 0)
//   p : modulus, 1 <= p <= ~3e9 so that the product of two residues fits in long long
// `long long` is the same type as the file's `ll`.
inline long long qpow(long long a,long long b,long long p)
{
    // Reduce the base up front. Otherwise a*a overflows for a >= ~3.04e9,
    // and a negative base would yield a negative "residue".
    a%=p;
    if(a<0)
        a+=p;
    long long ret=1%p;  // 1 % p makes p == 1 return 0
    while(b>0)
    {
        if(b&1)
            ret=ret*a%p;
        a=a*a%p;
        b>>=1;
    }
    return ret;
}
// Modular inverse by Fermat's little theorem: x^(p-2) is x^{-1} (mod p).
// Valid only when p is prime and x is not a multiple of p; otherwise the
// returned value is not an inverse.
inline ll inv(ll x,ll p)
{
    const ll fermat_exponent=p-2;
    return qpow(x,fermat_exponent,p);
}
inline void ntt1(re ll a[],re ll n,re ll typ)
{
    re ll num=1,bit=0;
    while(num<n)
    {
        num<<=1;
        bit++;
    }
    for(re int i=0;i<n;i++)
    {
        rev[i]=(rev[i>>1]>>1)|((i&1)<<bit-1);
        if(i<rev[i])
            swap(a[i],a[rev[i]]);
    }
    for(re int mid=1;mid<n;mid<<=1)
    {
        re ll wn=qpow((typ==1)?g:inv1g,(mod1-1)/(mid<<1),mod1);
        for(re int j=0;j<n;j+=mid<<1)
        {
            re ll w=1;
            for(re int i=0;i<mid;i++,w*=wn,w%=mod1)
            {
                re ll x=a[i+j],y=a[i+j+mid]*w%mod1;
                a[i+j]=(x+y)%mod1,a[i+j+mid]=(x-y+mod1)%mod1;
            }
        }
    }
    return;
}
inline void ntt2(re ll a[],re ll n,re ll typ)
{
    re ll num=1,bit=0;
    while(num<n)
    {
        num<<=1;
        bit++;
    }
    for(re int i=0;i<n;i++)
    {
        rev[i]=(rev[i>>1]>>1)|((i&1)<<bit-1);
        if(i<rev[i])
            swap(a[i],a[rev[i]]);
    }
    for(re int mid=1;mid<n;mid<<=1)
    {
        re ll wn=qpow((typ==1)?g:inv2g,(mod2-1)/(mid<<1),mod2);
        for(re int j=0;j<n;j+=mid<<1)
        {
            re ll w=1;
            for(re int i=0;i<mid;i++,w*=wn,w%=mod2)
            {
                re ll x=a[i+j],y=a[i+j+mid]*w%mod2;
                a[i+j]=(x+y)%mod2,a[i+j+mid]=(x-y+mod2)%mod2;
            }
        }
    }
    return;
}
inline void ntt3(re ll a[],re ll n,re ll typ)
{
    re ll num=1,bit=0;
    while(num<n)
    {
        num<<=1;
        bit++;
    }
    for(re int i=0;i<n;i++)
    {
        rev[i]=(rev[i>>1]>>1)|((i&1)<<bit-1);
        if(i<rev[i])
            swap(a[i],a[rev[i]]);
    }
    for(re int mid=1;mid<n;mid<<=1)
    {
        re ll wn=qpow((typ==1)?g:inv3g,(mod3-1)/(mid<<1),mod3);
        for(re int j=0;j<n;j+=mid<<1)
        {
            re ll w=1;
            for(re int i=0;i<mid;i++,w*=wn,w%=mod3)
            {
                re ll x=a[i+j],y=a[i+j+mid]*w%mod3;
                a[i+j]=(x+y)%mod3,a[i+j+mid]=(x-y+mod3)%mod3;
            }
        }
    }
    return;
}
inline void solve1()
{
    ntt1(a,num,1);
    ntt1(b,num,1);
    for(re int i=0;i<num;i++)
        ans1[i]=a[i]*b[i]%mod1;
    ntt1(ans1,num,-1);
    ntt1(a,num,-1);
    ntt1(b,num,-1);
    for(re int i=0;i<=n;i++)
        a[i]=a[i]*inv(num,mod1)%mod1;
    for(re int i=0;i<=m;i++)
        b[i]=b[i]*inv(num,mod1)%mod1;
    return;
}
inline void solve2()
{
    ntt2(a,num,1);
    ntt2(b,num,1);
    for(re int i=0;i<num;i++)
        ans2[i]=a[i]*b[i]%mod2;
    ntt2(ans2,num,-1);
    ntt2(a,num,-1);
    ntt2(b,num,-1);
    for(re int i=0;i<=n;i++)
        a[i]=a[i]*inv(num,mod2)%mod2;
    for(re int i=0;i<=m;i++)
        b[i]=b[i]*inv(num,mod2)%mod2;
    return;
}
inline void solve3()
{
    ntt3(a,num,1);
    ntt3(b,num,1);
    for(re int i=0;i<num;i++)
        ans3[i]=a[i]*b[i]%mod3;
    ntt3(ans3,num,-1);
    ntt3(a,num,-1);
    ntt3(b,num,-1);
    for(re int i=0;i<=n;i++)
        a[i]=a[i]*inv(num,mod3)%mod3;
    for(re int i=0;i<=m;i++)
        b[i]=b[i]*inv(num,mod3)%mod3;
    return;
}
int main()
{
    inv1g=inv(g,mod1);
    inv2g=inv(g,mod2);
    inv3g=inv(g,mod3);
    read(n);
    read(m);
    read(p);
    for(re int i=0;i<=n;i++)
        read(a[i]);
    for(re int i=0;i<=m;i++)
        read(b[i]);
    num=1;
    while(num<n+m+1)
        num<<=1;
    solve1();
    solve2();
    solve3();
    for(re int i=0;i<n+m+1;i++)
    {
        re ll x1=ans1[i]*inv(num,mod1)%mod1,x2=ans2[i]*inv(num,mod2)%mod2,x3=ans3[i]*inv(num,mod3)%mod3;
        re ll k1,k2,k3;
        k1=((x2-x1)%mod2+mod2)%mod2*inv(mod1,mod2)%mod2;
        re ll x4=x1+k1*mod1;
        re ll lcm=mod1*mod2%mod3;
        re ll k4=((x3-x4)%mod3+mod3)%mod3*inv(lcm,mod3)%mod3;
        if(i)
            putchar(' ');
        printf("%lld",(x4%p+(k4%p)*(mod1*mod2%p)%p)%p);
    }
    putchar('\n');
    exit(0);
}

|