求助:用拆系数 FFT 实现 MTT(任意模数多项式乘法)

P4245 【模板】任意模数多项式乘法

zhiyangfan @ 2022-01-26 16:48:50

思路来自 cmd 的题解。

https://www.luogu.com.cn/blog/command-block/solution-p4245

代码:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>
typedef long double ld; typedef long long ll;
// N bounds the global arrays (must exceed the FFT length).
// B = 2^15 is the split base: every coefficient v is handled as (v / B) * B + v % B.
const int N = 2e6 + 10, B = 1 << 15;
// acos(-1.0) is the double overload and would give pi only to double precision;
// evaluate it in long double so the stored value actually has long double accuracy.
const ld PI = std::acos(-1.0L);
// Complex number with long double components, providing only the operations the FFT uses.
struct cp
{
    ld x, y;  // real part, imaginary part
    cp(ld x = 0, ld y = 0) : x(x), y(y) { }
    // const-qualified so the operators also apply to const operands.
    cp operator+(const cp& c) const { return cp(x + c.x, y + c.y); }
    cp operator-(const cp& c) const { return cp(x - c.x, y - c.y); }
    cp operator*(const cp& c) const { return cp(x * c.x - y * c.y, x * c.y + y * c.x); }
};
// F1, F2, G: transform buffers used by main; ans: output coefficients;
// rev, lim, m: bit-reversal table, FFT length and log2 of that length, written by getLR.
cp F1[N], F2[N], G[N]; int ans[N], rev[N], lim, m;
// Sets up an FFT of length lim:
//   lim  = smallest power of two strictly greater than n,
//   m    = log2(lim),
//   rev[i] = i with its m low bits reversed, for every i in [0, lim).
// rev[i] is derived from rev[i / 2], which has already been filled because i / 2 < i.
inline void getLR(int n)
{
    for (lim = 1, m = 0; lim <= n; ++m) lim <<= 1;
    const int topShift = m - 1;
    for (int i = 0; i != lim; ++i)
    {
        const int lowBit = i & 1;
        rev[i] = (rev[i / 2] >> 1) | (lowBit << topShift);
    }
}
// In-place iterative radix-2 FFT on c[0, len).
//   on =  1 : forward transform  (roots exp(+2*pi*i*k/len))
//   on = -1 : inverse transform  (roots exp(-2*pi*i*k/len)); the result is divided by len.
// len must be a power of two and rev[] must hold the matching bit-reversal permutation
// (as produced by getLR).
//
// Precision note: the twiddle factors are NOT accumulated as w = w * wn. With lengths
// around 2^18 that product chain runs ~1.3e5 times and, combined with a double-precision
// wn, produces relative errors near 1e-11. Against coefficients of size ~1e14 that is far
// beyond the 0.5 rounding budget. Instead every root is evaluated directly in long double
// and cached, so each twiddle carries only a single rounding error.
inline void FFT(cp* c, int len, int on)
{
    // root[k] = exp(2*pi*i*k/len) for k in [0, len/2); rebuilt only when len changes.
    static std::vector<cp> root;
    static int rootLen = 0;
    if (rootLen != len)
    {
        const ld pi = std::acos(-1.0L);  // long double pi, independent of any global constant
        root.assign(len / 2, cp());
        for (int k = 0; k < len / 2; ++k)
            root[k] = cp(std::cos(2 * pi * k / len), std::sin(2 * pi * k / len));
        rootLen = len;
    }
    for (int i = 0; i < len; ++i) if (i < rev[i]) std::swap(c[i], c[rev[i]]);
    for (int h = 2; h <= len; h <<= 1)
    {
        const int half = h / 2, step = len / h;  // stage-h root k equals root[k * step]
        for (int j = 0; j < len; j += h)
            for (int k = 0; k < half; ++k)
            {
                const cp& r = root[k * step];
                cp w(r.x, on * r.y);  // conjugate root for the inverse transform
                cp u = c[j + k], t = w * c[j + k + half];
                c[j + k] = u + t; c[j + k + half] = u - t;
            }
    }
    if (on == -1) for (int i = 0; i < len; ++i) c[i].x /= len, c[i].y /= len;
}
int main()
{
    int n, m, p; scanf("%d%d%d", &n, &m, &p);
    for (int i = 0, v; i <= n; ++i) scanf("%d", &v), F1[i] = cp(v / B, v % B), F2[i] = cp(v / B, -v % B);
    for (int i = 0, v; i <= m; ++i) scanf("%d", &v), G[i] = cp(v / B, v % B);
    getLR(n + m + 1); FFT(F1, lim, 1); FFT(F2, lim, 1); FFT(G, lim, 1);
    for (int i = 0; i < lim; ++i) F1[i] = F1[i] * G[i], F2[i] = F2[i] * G[i];
    FFT(F1, lim, -1); FFT(F2, lim, -1);
    for (int i = 0; i < n + m + 1; ++i)
    {
        ll a1b1, a1b2, a2b1, a2b2;
        a1b1 = (ll)floor((F1[i].x + F2[i].x) / 2 + 0.49) % p;
        a1b2 = (ll)floor((F1[i].y + F2[i].y) / 2 + 0.49) % p;
        a2b1 = ((ll)floor(F1[i].y + 0.49) - a1b2) % p;
        a2b2 = ((ll)floor(F2[i].x + 0.49) - a1b1) % p;
        ans[i] = ((a1b1 * B + (a1b2 + a2b1)) * B + a2b2) % p;
        (ans[i] += p) %= p; printf("%d ", ans[i]);
    }
    puts(""); return 0;
}

|