求助 NTT 模板

P4245 【模板】任意模数多项式乘法

phoenixzhan @ 2023-08-22 18:35:46

#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define pii pair<int, int>
#define mp make_pair
#define fi first
#define se second 
#define deb(var) cerr << #var << '=' << var << "; "
#define int long long
int mod;
const int p1 = 998244353, p2 = 1004535809, p3 = 1998585857, g = 3;
#define cInt const Int &
int power(int x, int y, int mod);
struct Int {
    int a, b, c;
    Int() {
        a = b = c = 0;
    }
    Int(int x) {
        a = b = c = x;
    } 
    Int(int x, int y, int z) {
        a = x, b = y, c = z;
    }
    static Int reduce(Int x) {
        return Int((x.a + p1 * p1) % p1, (x.b + p2 * p2) % p2, (x.c + p3 * p3) % p3);
    }
    const Int friend operator + (cInt a, cInt b) {
        return reduce(Int(a.a + b.a, a.b + b.b, a.c + b.c));
    }
    const Int friend operator - (cInt a, cInt b) {
        return reduce(Int(a.a - b.a, a.b - b.b, a.c - b.c));
    }
    const Int friend operator * (cInt a, cInt b) {
        return reduce(Int(a.a * b.a, a.b * b.b, a.c * b.c));
    }
    int get() {
        const int k1 = (b + p2 - a) * power(p1, p2 - 2, p2) % p2, d = (k1 * p1 + a) % (p1 * p2);
        const int K1 = (c + p3 - d % p3) * power(p1 * p2, p3 - 2, p3) % p3; return (K1 * p1 % mod * p2 % mod + d) % mod;
    }
};
int power(int x, int y, int mod) {
    if (y < 0) y += mod - 1;
    int ans = 1;
    while (y) {
        if (y & 1) (ans *= x) %= mod; (x *= x) %= mod; y >>= 1; 
    } 
    return ans;
}
Int power(Int x, Int y) {
    return Int(power(x.a, y.a, p1), power(x.b, y.b, p2), power(x.c, y.c, p3));
}
int bit, tot, rev[400010];
void init(int n) {
    for (bit = 1; (1 << bit) < n; bit++); tot = (1 << bit);
    for (int i = 0; i < tot; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1)); 
}
void NTT(Int *a, int fac) {
    for (int i = 0; i < tot; i++)
        if (rev[i] > i) swap(a[i], a[rev[i]]);
    for (int len = 1; len < tot; len <<= 1) {
        Int w = power(Int(g), Int((p1 - 1) / (len << 1) * fac, 
                                  (p2 - 1) / (len << 1) * fac,
                                  (p3 - 1) / (len << 1) * fac));
        for (int i = 0; i < tot; i += len << 1) {
            Int wk = Int(1);
            for (int k = 0; k < len; k++, wk = wk * w) {
                Int x = a[i + k], y = a[i + k + len] * wk;
                a[i + k] = x + y; a[i + k + len] = x - y; 
            }
        }
    }
    if (fac == -1) {
        Int inv = power(Int(tot), Int(p1 - 2, p2 - 2, p3 - 2));
        for (int i = 0; i < tot; i++) a[i] = a[i] * inv;
    }
}
void calc(Int *a, Int *b) {
    for (int i = 0; i < tot; i++) a[i] = a[i] * b[i];
} 
int n, m; Int a[400010], b[400010];
signed main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n >> m >> mod;
    init(n + m + 1);
    for (int i = 0; i <= n; i++) cin >> a[i].a, a[i] = Int().reduce(Int(a[i].a));
    for (int i = 0; i <= m; i++) cin >> b[i].a, b[i] = Int().reduce(Int(b[i].a));
    NTT(a, 1); NTT(b, 1); calc(a, b); NTT(a, -1);
    for (int i = 0; i <= n + m; i++) cout << a[i].get() << " ";
    return 0;
}

|