中国剩余定理,求查错

P4245 【模板】任意模数多项式乘法

阿尔托莉雅丶 @ 2022-02-27 20:18:38

M = \prod_{i = 1}^n m_iM_i = \frac {M}{m_i}, M_i^{-1} M_i \equiv 1 \ (mod\ m_i)

x \equiv \sum_{i = 1}^n M_iM_i^{-1}a_i, 并且解数为 1

#include <iostream>
#include <cstring>
using namespace std;
const int N = 2e5 + 5;   //remember to modify the range of the data!!
// const int mod = 998244353;
const int p[3] = {998244353, 1004535809, 469762049};
const int G = 3; // 他们共同的原根
typedef long long ll;

int n, m, t;
int limit = 1; //大于等于结果系数个数的最小2的幂
int bitnum = 0; //上述幂的次数
int r[N << 1]; //记录反转二进制后的值
int gn[N << 1];
ll f[N << 1], g[N << 1], a[N << 1], b[N << 1], ans[3][N << 1];
ll M[3];
ll invM[3];

ll qpow(ll a, ll p, int mod)
{
    ll res = 1;
    while(p)
    {
        if(p & 1)
            res = res * a % mod;
        a = a * a % mod;
        p >>= 1;
    }
    return res;
}

void NTT(ll a[], int type, int p)//type = 1 ntt, type = -1 intt
{
    for(int i = 0; i < limit; i++) 
    if(i < r[i])  //为了防止前面交换后面又交换回来所以用 < 而不用 !=
        swap(a[i], a[r[i]]);
    for(int mid = 1; mid < limit; mid <<= 1) //从底层往上合并 枚举待合并区间长度的一半
    {
        //最开始是两个长度为1的序列合并,mid = 1;
        //这里要乘二,因为fft里分子是2pi消掉了一个2所以那点没乘,但是这里要乘
        // ll Gn = qpow(type == 1 ? G : invG, (p - 1) / (mid << 1));
        ll Gn = type == 1 ? gn[limit / (mid << 1)] : gn[limit - limit / (mid << 1)];
        for(int len = mid << 1, pos = 0; pos < limit; pos += len)
        {
            ll w = 1; //幂,一直乘,得到平方,三次方...
            for(int k = 0; k < mid; k++, w = w * Gn % p)
            {
                int x = a[pos + k]; //左边部分
                int y = w * a[pos + mid + k] % p; //右边部分
                a[pos + k] = (x + y) % p; //左边加
                a[pos + mid + k] = ((x - y) % p + p) % p; //右边减
            }
        }
    }
    if(type == 1)
        return; 
    ll inv = qpow(limit, p - 2, p);
    for(int i = 0; i < limit; i++)//逆ntt最后要除以limit也就是补成了2的
        a[i] = a[i] * inv % p;     //整数幂的那个N,将点值转换为系数
}
void init(int p, int deg)
{
    // 补成2的整次幂, 注意初始化
    for(limit = 1, bitnum = 0; limit < deg + 1; limit <<= 1)
        bitnum++;
    for(int i = 0; i < limit; i++) //二进制反转
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bitnum - 1));
    gn[0] = 1;
    gn[1] = qpow(G, (p - 1) / limit, p);
    for(int i = 2; i <= limit; i++)
        gn[i] = 1ll * gn[i - 1] * gn[1] % p;
}

void polmul(ll a[], ll b[], int p, int deg)
{
    init(p, deg);  // 初始化limit和原根
    NTT(a, 1, p);
    NTT(b, 1, p);
    for(int i = 0; i < limit; i++)
        a[i] = a[i] * b[i] % p;
    NTT(a, -1, p);
}
void CRT(int ansMod)
{
    //不同模数下计算
    memcpy(ans[0], a, sizeof(a[0]) * (n + 1));
    memcpy(g, b, sizeof(b[0]) * (m + 1));
    polmul(ans[0], g, p[0], n + m);

    memcpy(ans[1], a, sizeof(a[0]) * (n + 1));
    memcpy(g, b, sizeof(b[0]) * (m + 1));
    polmul(ans[1], g, p[1], n + m);

    memcpy(ans[2], a, sizeof(a[0]) * (n + 1));
    memcpy(g, b, sizeof(b[0]) * (m + 1));
    polmul(ans[2], g, p[2], n + m);
    //CRT合并
    for(int i = 0; i < 3; i++)
    {
        M[i] = ll(p[(i + 2) % 3]) * p[(i + 1) % 3] % p[i];
        invM[i] = qpow(M[i], p[i] - 2, p[i]);
        // cout << M[i] << ' ' << invM[i] << '\n';
    }

    for(int i = 0; i <= n + m; i++)
    {
        ll res = 0;
        for(int j = 0; j < 3; j++)
            res = (res + M[j] * invM[j] % ansMod * ans[j][i] % ansMod) % ansMod;
        ans[0][i] = res;
    }
}

int main(void)
{
    ll Mod;
    cin >> n >> m >> Mod;
    for(int i = 0; i <= n; i++)
        cin >> a[i];
    for(int i = 0; i <= m; i++)
        cin >> b[i];
    CRT(Mod);

    for(int i = 0; i <= n + m; i++)
        cout << ans[0][i] << ' ';
    return 0;
}

|