3模MTT求助

P4245 【模板】任意模数多项式乘法

sjwhsss @ 2024-12-19 15:03:50

RT,不开__int128会爆longlong,所有能模的都模了,修改取模的地方得分从0~50都有,不知道哪里有问题

#include <bits/stdc++.h>
#define int __int128
using namespace std;
const int maxn = 1e6+5 , mod1 = 998244353 , mod2 = 469762049 , mod3 = 1004535809 , G = 3 , G1 = 332748118 , G2 = 156587350 , G3 = 334845270;
int a1[maxn] , a2[maxn] , a3[maxn] , b1[maxn] , b2[maxn] , b3[maxn] , ans[maxn] , r[maxn];
int qpow(int a , int b , int mod)
{
    int res = 1;
    while(b)
    {
        if (b & 1)(res*=a)%=mod;
        (a*=a)%=mod;
        b>>=1;
    }
    return res;
}
void NTT(int *a , int lim , int t , int mod , int Gi)
{
    for (int i = 1; i < lim; i++) if (i < r[i])swap(a[i] , a[r[i]]);
    for (int i = 1; i < lim; i<<=1)
    {
        int ome = qpow(t == 1 ? G : Gi , (mod - 1)/(i<<1) , mod);
        for (int j = 0; j < lim; j+=i<<1)
        {
            int w = 1;
            for (int k = 0; k < i; k++ , (w*=ome)%=mod)
            {
                int x = a[j + k] , y = w * a[j + k + i] % mod;
                a[j + k] = (x + y)%mod;
                a[j + k + i] = (x - y + mod)%mod;
            }
        }
    }
    if (t == 1)return;
    int inv = qpow(lim , mod - 2 , mod);
    for (int i = 0; i < lim; i++) (a[i]*=inv)%=mod;
    return;
}
void Mul(int *a , int *b , int n , int m , int mod , int Gi)
{
    int lim = 1 , t = 0;
    while(lim <= n + m)lim<<=1,t++;
    for (int i = 1; i < lim; i++) r[i]=(r[i>>1]>>1)|((i&1)<<t-1);
    NTT(a , lim , 1 , mod , Gi) , NTT(b , lim , 1 , mod , Gi);
    for (int i = 0; i < lim; i++)(a[i]*=b[i])%=mod;
    NTT(a , lim , -1 , mod , Gi);
    return;
}
inline int read()
{
    char ch=getchar();int x=0;
    while(isdigit(ch)^1)ch=getchar();
    while(isdigit(ch))x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
    return x;
}
inline void print(int x)
{
    if (x > 9)print(x/10);
    putchar(x%10^48);
    return;
}
signed main ()
{
    int n=read() , m=read() , p=read();
    for (int i = 0; i <= n; i++) a2[i]=a3[i]=a1[i]=read();
    for (int i = 0; i <= m; i++) b2[i]=b3[i]=b1[i]=read();
    Mul(a1 , b1 , n , m , mod1 , G1);
    Mul(a2 , b2 , n , m , mod2 , G2);
    Mul(a3 , b3 , n , m , mod3 , G3);
    for (int i = 0; i < n + m + 1; i++)
    {
        int x = ((a2[i] - a1[i] + mod2 + mod2)%mod2 * qpow(mod1 , mod2 - 2 , mod2)%mod2 * mod1%(mod1*mod2) + a1[i])%(mod1*mod2);
        ans[i] = ((((a3[i] - x%mod3 + mod3 + mod3)%mod3%(mod1*mod2*mod3) * qpow(mod1%mod3 * mod2%mod3 , mod3 - 2 , mod3)%mod3)%(mod1*mod2*mod3) * (mod1*mod2)%(mod1*mod2*mod3))%(mod1*mod2*mod3) + x%(mod1*mod2*mod3))%p;
    }
    for (int i = 0; i < n + m + 1; i++) print(ans[i]),putchar(32);
    return 0;
}

|