求助拆系数 FFT 模板题

P4245 【模板】任意模数多项式乘法

phoenixzhan @ 2023-08-12 19:00:13

#include <bits/stdc++.h>
using namespace std;
// Shorthand macros. None of pb / pii / mp / fi / se / deb is used by the
// code visible in this file.
#define pb push_back
#define pii pair<int, int>
#define mp make_pair
#define fi first
#define se second 
// Debug helper: writes `<expression text>=<value>; ` to cerr.
#define deb(var) cerr << #var << '=' << var << "; "
// From here on every `int` token expands to `long long`. main is written as
// `signed main` below because `int main` would expand to `long long main`.
#define int long long
// pi, obtained as acos(-1); used to build the FFT twiddle angles.
const double PI = acos(-1);
// Minimal complex number: `a` is the real part, `b` the imaginary part.
struct Complex {
    double a, b;
    // Component-wise sum.
    const Complex operator + (const Complex &rhs) const {
        Complex out = {a + rhs.a, b + rhs.b};
        return out;
    }
    // Component-wise difference.
    const Complex operator - (const Complex &rhs) const {
        Complex out = {a - rhs.a, b - rhs.b};
        return out;
    }
    // Complex product: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
    const Complex operator * (const Complex &rhs) const {
        Complex out = {a * rhs.a - b * rhs.b, a * rhs.b + b * rhs.a};
        return out;
    }
};
// Global, zero-initialised work buffers. A1, A2 and B are the ones main fills
// and transforms; the global arrays a1, a2, b1, b2 are never accessed by the
// code visible in this file.
Complex a1[400010], a2[400010], b1[400010], b2[400010], A1[400010], A2[400010], B[400010];
// bit : log2 of the transform size (never less than 1)
// tot : transform size, always 1 << bit
// rev : rev[i] is i with its lowest `bit` bits reversed
int bit, tot, rev[400010];

// Picks the smallest power of two tot >= n (with tot >= 2) and fills
// rev[1 .. tot-1] with the bit-reversal permutation for that size.
// rev[0] is never written; it relies on the global being zero-initialised.
void init(int n) {
    bit = 1;
    while ((1 << bit) < n) ++bit;
    tot = 1 << bit;
    for (int idx = 1; idx < tot; ++idx) {
        // Reverse of idx = reverse of (idx >> 1) shifted down one place,
        // with idx's lowest bit moved to the top position.
        const int top = (idx & 1) << (bit - 1);
        rev[idx] = (rev[idx >> 1] >> 1) | top;
    }
}
// In-place iterative radix-2 FFT over a[0 .. tot-1].
//   a   : buffer holding at least `tot` elements; `tot` and `rev` must have
//         been prepared by init() beforehand.
//   fac : +1 for the forward transform, -1 for the inverse transform. The
//         inverse also divides every element by tot.
void FFT(Complex *a, int fac) {
    // Bit-reversal permutation; the rev[i] > i guard swaps each pair once.
    for (int i = 0; i < tot; i++)
        if (rev[i] > i) std::swap(a[rev[i]], a[i]);

    // Twiddle factors of the current level. Each one is evaluated directly
    // from cos/sin rather than by repeatedly multiplying by a base root:
    // repeated multiplication lets rounding error grow along the butterfly
    // group, which is too much when results must round to exact integers.
    std::vector<Complex> roots(tot > 1 ? tot >> 1 : 1);
    for (int len = 1; len < tot; len <<= 1) {
        for (int j = 0; j < len; j++) {
            const double angle = PI * j / len;
            roots[j] = Complex{std::cos(angle), std::sin(angle) * fac};
        }
        for (int i = 0; i < tot; i += len << 1) {
            for (int j = 0; j < len; j++) {
                const Complex x = a[i + j], y = a[i + j + len] * roots[j];
                a[i + j] = x + y;
                a[i + j + len] = x - y;
            }
        }
    }

    // Inverse transform: scale all tot outputs, indices 0 .. tot-1.
    // (Starting at 1 and stopping at tot would leave a[0] unscaled and
    // modify a[tot], which is outside the transformed range.)
    if (fac == -1)
        for (int i = 0; i < tot; i++) a[i].a /= tot, a[i].b /= tot;
}
// Point-wise product of two spectra: a[k] <- a[k] * b[k] for k in [0, tot).
// `b` is only read.
void calc(Complex *a, Complex *b) {
    int k = 0;
    while (k < tot) {
        a[k] = a[k] * b[k];
        ++k;
    }
}
// n, m : degrees of the two input polynomials
// mod  : modulus for the output coefficients
// blk  : number of low bits split off each coefficient (x = hi * 2^blk + lo)
int n, m, mod, blk = 15;

// Reads n, m, mod and the coefficients of both polynomials, multiplies them
// with the split-coefficient FFT trick and prints the n + m + 1 product
// coefficients modulo mod, separated by spaces.
//
// With A1 = a_hi + i*a_lo, A2 = a_hi - i*a_lo and B = b_hi + i*b_lo:
//   A1 * B = (a_hi*b_hi - a_lo*b_lo) + i*(a_hi*b_lo + a_lo*b_hi)
//   A2 * B = (a_hi*b_hi + a_lo*b_lo) + i*(a_hi*b_lo - a_lo*b_hi)
// so the four partial products are half-sums / half-differences of the real
// and imaginary parts of the two convolutions.
signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    cin >> n >> m >> mod;
    init(n + m + 1);

    const int low_mask = (1LL << blk) - 1;
    for (int i = 0; i <= n; i++) {
        int coef; cin >> coef;
        // Reducing first keeps both pieces as small as possible, which
        // helps floating-point accuracy.
        coef = (coef % mod + mod) % mod;
        const int hi = coef >> blk, lo = coef & low_mask;
        A1[i].a = hi, A1[i].b = lo;
        A2[i].a = hi, A2[i].b = -lo;
    }
    for (int i = 0; i <= m; i++) {
        int coef; cin >> coef;
        coef = (coef % mod + mod) % mod;
        B[i].a = coef >> blk, B[i].b = coef & low_mask;
    }

    FFT(A1, 1); FFT(A2, 1); FFT(B, 1);
    calc(A1, B); calc(A2, B);
    FFT(A1, -1); FFT(A2, -1);

    // Round-to-nearest, then reduce into [0, mod). The individual real and
    // imaginary parts can be negative, so truncating `x + 0.5` toward zero
    // would round them incorrectly; llround is applied to the combined
    // value instead.
    const auto to_residue = [](double value) {
        const int r = std::llround(value) % mod;
        return r < 0 ? r + mod : r;
    };
    // 2^blk and 2^(2*blk) as residues, so recombination never overflows.
    const int base1 = (1LL << blk) % mod, base2 = base1 * base1 % mod;

    for (int i = 0; i <= n + m; i++) {
        const int a1b1 = to_residue((A1[i].a + A2[i].a) / 2);
        const int a1b2 = to_residue((A1[i].b + A2[i].b) / 2);
        const int a2b1 = to_residue((A1[i].b - A2[i].b) / 2);
        const int a2b2 = to_residue((A2[i].a - A1[i].a) / 2);
        const int mid = (a1b2 + a2b1) % mod;
        cout << (a1b1 * base2 % mod + mid * base1 % mod + a2b2) % mod << " ";
    }
    return 0;
}

|