绝望

P4245 【模板】任意模数多项式乘法

小菜鸟 @ 2018-08-26 19:40:06

对拆系数FFT绝望的我试了试三模数NTT。。。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int P1=998244353,P2=1004535809,P3=469762049,g=3;
int power(long long a,int b,int P)
{
    a%=P;
    int res=1;
    while(b)
    {
        (b&1)&&(res=(long long)res*a%P);
        b>>=1;
        a=(long long)a*a%P;
    }
    return res;
}
int rev[1<<18];

void read(int& x)
{
    x=0;
    bool sym=0;
    char c=getchar();
    while(c<48||c>57)sym|=(c==45),c=getchar();
    while(c>47&&c<58)x=(x<<1)+(x<<3)+(c^48),c=getchar();
    if(sym)x=~x+1;
}
void write(int x)
{
    if(x<0)putchar(45),x=~x+1;
    if(x>=10)write(x/10);
    putchar((x%10)^48);
}

void getrev(int bit)
{
    int n=1<<bit;
    for(int i=0;i<n;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
}
void ntt(int* a,int n,int P,int dft=1)
{
    int gi=power(g,P-2,P);
    for(int i=0;i<n;++i)if(i<rev[i])swap(a[i],a[rev[i]]);
    for(int step=1;step<n;step<<=1)
    {
        int wn=power(dft==1?g:gi,(P^1)/(step<<1),P);
        for(int j=0;j<n;j+=step<<1)
        {
            int wnk=1;
            for(int k=j;k<j+step;++k)
            {
                int x=a[k];
                int y=(long long)a[k+step]*wnk%P;
                a[k]=(x+y)%P;
                a[k+step]=(x-y+P)%P;
                wnk=(long long)wnk*wn%P;
            }
        }
    }
    if(dft==-1)
    {
        int inv=power(n,P-2,P);
        for(int i=0;i<n;++i)a[i]=(long long)a[i]*inv%P;
    }
}
int CRT(int a1,int a2,int a3,int P)
{
    int res=0;
    int inv1=power((long long)P2*P3,P1-2,P1)%P;
    int inv2=power((long long)P1*P3,P2-2,P2)%P;
    int inv3=power((long long)P1*P2,P3-2,P3)%P;
    res=(res+(long long)a1*P2%P*P3%P*inv1%P)%P;
    res=(res+(long long)a2*P1%P*P3%P*inv2%P)%P;
    res=(res+(long long)a3*P1%P*P2%P*inv3%P)%P;
    return res;
}
int n,m,P,a[1<<18],b[1<<18],_a[1<<18],_b[1<<18],res1[1<<18],res2[1<<18],res3[1<<18];
int main()
{
    read(n),read(m),read(P);
    for(int i=0;i<=n;++i)read(a[i]);
    for(int i=0;i<=m;++i)read(b[i]);
    int len=n+m+1,bit=0,s=1;
    while(s<len)++bit,s<<=1;
    getrev(bit);
    memcpy(_a,a,sizeof a);
    memcpy(_b,b,sizeof b);
    ntt(_a,s,P1),ntt(_b,s,P1);
    for(int i=0;i<s;++i)res1[i]=(long long)_a[i]*_b[i]%P1;
    ntt(res1,s,P1,-1);
    memcpy(_a,a,sizeof a);
    memcpy(_b,b,sizeof b);
    ntt(_a,s,P2),ntt(_b,s,P2);
    for(int i=0;i<s;++i)res2[i]=(long long)_a[i]*_b[i]%P2;
    ntt(res2,s,P2,-1);
    memcpy(_a,a,sizeof a);
    memcpy(_b,b,sizeof b);
    ntt(_a,s,P3),ntt(_b,s,P3);
    for(int i=0;i<s;++i)res3[i]=(long long)_a[i]*_b[i]%P3;
    ntt(res3,s,P3,-1);
    for(int i=0;i<len;++i)write(CRT(res1[i],res2[i],res3[i],P)),putchar(' ');
}

结果:5分,除#12外全WA。。。
我上辈子造了什么孽QAQ
有没有大佬帮我看看我的CRT啊


by Alioth_ @ 2019-01-23 10:27:50

我的也是


by Alioth_ @ 2019-01-23 10:28:06

#include<bits/stdc++.h>

using namespace std;

const int maxn=8e5+100;
const long long mod1=469762049;
const long long mod2=998244353;
const long long mod3=1004535809;
const long long M=1ll*mod1*mod2;
long long fast_mul(long long a,long long b,long long mod) {
    a %= mod, b %= mod;
    return ((a * b - (long long)((long long)((long double)a /mod*b+1e-3) * mod)) % mod + mod) % mod;
}
long long aa[maxn],bb[maxn],a[maxn],b[maxn],r[maxn],ans[5][maxn];
long long limit,n,m,l,p;
long long poww(long long a,long long b,long long mod)
{
    long long ans=1;
    while(b)
    {
        if(b&1)ans=(long long)ans*a%mod;
        b>>=1;
        a=(long long)a*a%mod;
    }
    return ans%mod;
}

long long inv(long long x,long long mod)
{
    return poww(x,mod-2,mod)%mod;
}

void NTT(long long * A,long long type,long long mod)
{
    long long pr=3;
    for(int i=0;i<limit;i++)if(r[i]>i)swap(A[r[i]],A[i]);
    for(int mid=1;mid<limit;mid<<=1)
    {
        long long wn=poww(type==-1?inv(pr,mod):pr,(mod-1)/(mid<<1),mod);//相当于原根gn
        long long w=1;                                          //gn^0和gn^p-1(由费马小定理得)和1同余模p
        for(int j=0,R=mid<<1;j<limit;j+=R)               //gn的1到p-1次方各不相同(原根的定义)
        {                                                //gn满足消去引理  
            long long w=1;                                       //由此可得 可以用gn代替wn 
            for(int k=0;k<mid;k++,w=(w*wn)%mod)
            {
                long long x=A[j+k],y=w*A[j+k+mid]%mod;
                A[j+k]=(x+y)%mod;
                A[j+k+mid]=(x-y+mod)%mod;
            }
        }
    }
}

int main()
{
    scanf("%lld%lld%lld",&n,&m,&p);
    for(int i=0;i<=n;i++)
        scanf("%lld",&aa[i]);
    for(int i=0;i<=m;i++)
        scanf("%lld",&bb[i]);
    limit=1;
    while(limit<=n+m)limit<<=1,l++;
    for(int i=0;i<=limit;i++)
        r[i]=r[i>>1]>>1|(i&1)<<l-1;
//------------------------------------------------------------------//
    memcpy(a,aa,sizeof(aa));
    memcpy(b,bb,sizeof(bb));
    NTT(a,1,mod1);
    NTT(b,1,mod1);
    for(int i=0;i<limit;i++)a[i]=a[i]*b[i]%mod1;
    NTT(a,-1,mod1);
    for(int i=0;i<=n+m;i++)ans[1][i]=(a[i]%mod1*inv(limit,mod1)%mod1+mod1)%mod1;

    memcpy(a,aa,sizeof(aa));
    memcpy(b,bb,sizeof(bb));
    NTT(a,1,mod2);
    NTT(b,1,mod2);
    for(int i=0;i<limit;i++)a[i]=a[i]*b[i]%mod2;
    NTT(a,-1,mod2);
    for(int i=0;i<=n+m;i++)ans[2][i]=(a[i]%mod2*inv(limit,mod2)%mod2+mod2)%mod2;

    memcpy(a,aa,sizeof(aa));
    memcpy(b,bb,sizeof(bb));
    NTT(a,1,mod3);
    NTT(b,1,mod3);
    for(int i=0;i<limit;i++)a[i]=a[i]*b[i]%mod3;
    NTT(a,-1,mod3);
    for(int i=0;i<=n+m;i++)ans[3][i]=(a[i]%mod3*inv(limit,mod3)%mod3+mod3)%mod3;
//------------------------------------------------------------------//
     for (int i = 0; i <= n + m; ++i) {
        long long A = (fast_mul(1ll * ans[1][i] * mod2 % M, inv(mod2%mod1,mod1), M) +
                       fast_mul(1ll * ans[2][i] * mod1 % M, inv(mod1%mod2,mod2), M)) % M;
        long long k = ((ans[3][i] - A) % mod3 + mod3) % mod3 * inv(mod1*mod2,mod3) % mod3;
      printf("%lld ", ((k % p) * (M % p) % p + A % p) % p);
    }
    return 0;
}

by Alioth_ @ 2019-01-23 10:28:25

只过了#12


|