三模数NTT 5pts 求de

P4245 【模板】任意模数多项式乘法

Meteorshower_Y @ 2022-08-16 16:48:23

#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
using namespace std;

namespace fast_IO
{
    #define FAST_IO
    #define IOSIZE 100000
    typedef long long ll;
    typedef double db;
    typedef long double ldb;
    typedef __int128_t i128;
    char ibuf[IOSIZE], obuf[IOSIZE];
    char *p1 = ibuf, *p2 = ibuf, *p3 = obuf;
    #ifdef ONLINE_JUDGE
        #define getchar() ((p1==p2)and(p2=(p1=ibuf)+fread(ibuf,1,IOSIZE,stdin),p1==p2)?(EOF):(*p1++))
        #define putchar(x) ((p3==obuf+IOSIZE)&&(fwrite(obuf,p3-obuf,1,stdout),p3=obuf),*p3++=x)
    #endif//fread in OJ, stdio in local

    #define isdigit(ch) (ch>47&&ch<58)
    #define isspace(ch) (ch<33&&ch!=EOF)

    struct fast_IO_t {
        ~fast_IO_t() {
            fwrite(obuf, p3-obuf, 1, stdout);
        }
        bool flag = false;
        operator bool() {
            return flag;
        }
    }io;

    template<typename T> inline T read() {
        T s = 0; int w = 1; char ch;
        while(ch=getchar(), !isdigit(ch)&&(ch!=EOF))
            if(ch == '-') w = -1;
        if(ch == EOF) return 0;
        while(isdigit(ch))
            s = s*10+ch-48, ch=getchar();
        if(ch == '.') {
            ll flt = 0; int cnt = 0;
            while(ch=getchar(), isdigit(ch))
                if(cnt < 18) flt=flt*10+ch-48, cnt++;
            s += (db)flt/pow(10,cnt);
        }
        return s *= w;
    }
    template<typename T> inline bool read(T &s) {
        s = 0; int w = 1; char ch;
        while(ch=getchar(), !isdigit(ch)&&(ch!=EOF))
            if(ch == '-') w = -1;
        if(ch == EOF) return false;
        while(isdigit(ch))
            s = s*10+ch-48, ch=getchar();
        if(ch == '.') {
            ll flt = 0; int cnt = 0;
            while(ch=getchar(), isdigit(ch))
                if(cnt < 18) flt=flt*10+ch-48, cnt++;
            s += (db)flt/pow(10,cnt);
        }
        return s *= w, true;
    }
    inline bool read(char &s) {
        while(s = getchar(), isspace(s));
        return s != EOF;
    }
    inline bool read(char *s) {
        char ch;
        while(ch=getchar(), isspace(ch));
        if(ch == EOF) return false;
        while(!isspace(ch))
            *s++ = ch, ch=getchar();
        *s = '\000';
        return true;
    } 
    template<typename T> void print(T x) {
        static int t[20]; int top = 0;
        if(x < 0) putchar('-'), x = -x;
        do { t[++top] = x%10; x /= 10; } while(x);
        while(top) putchar(t[top--]+48);
    }
    struct empty_type{}; int pcs = 8;
    empty_type setpcs(int cnt) {
        return pcs = cnt, empty_type();
    }
    inline void print(empty_type x){}
    inline void print(double x) {
        if(x < 0) putchar('-'), x = -x;
        x += 5.0 / pow(10,pcs+1);
        print((ll)(x)); x -= (ll)(x);
        if(pcs != 0) putchar('.');
        for(int i = 1; i <= pcs; i++)
        x *= 10, putchar((int)x+'0'), x -= (int)x;
    }
    inline void print(float x) {
        if(x < 0) putchar('-'), x = -x;
        x += 5.0 / pow(10,pcs+1);
        print((ll)(x)); x -= (ll)(x);
        if(pcs != 0) putchar('.');
        for(int i = 1; i <= pcs; i++)
        x *= 10, putchar((int)x+'0'), x -= (int)x;
    }
    inline void print(long double x) {
        if(x < 0) putchar('-'), x = -x;
        x += 5.0 / pow(10,pcs+1);
        print((i128)(x)); x -= (i128)(x);
        if(pcs != 0) putchar('.');
        for(int i = 1; i <= pcs; i++)
        x *= 10, putchar((int)x+'0'), x -= (int)x;
    }
    inline void print(char x) {
        putchar(x);
    }
    inline void print(char *x) {
        for(int i = 0; x[i]; i++)
            putchar(x[i]);
    }
    inline void print(const char *x) {
        for(int i = 0; x[i]; i++)
            putchar(x[i]);
    }
    #ifdef _GLIBCXX_STRING//string
        inline bool read(std::string& s) {
            s = ""; char ch;
            while(ch=getchar(), isspace(ch));
            if(ch == EOF) return false;
            while(!isspace(ch))
                s += ch, ch = getchar();
            return true;
        }
        inline void print(std::string x) {
            for(string::iterator i = x.begin(); i != x.end(); i++)
                putchar(*i);
        }
        inline bool getline(fast_IO_t &io, string s) {
            s = ""; char ch = getchar();
            if(ch == EOF) return false;
            while(ch != '\n' and ch != EOF)
                s += ch, ch = getchar();
            return true;
        }
    #endif
    #if __cplusplus >= 201103L
        template<typename T, typename... T1>
        inline int read(T& a, T1& ...other) {
            return read(a)+read(other...);
        }
        template<typename T, typename... T1>
        inline void print(T a, T1... other) {
            print(a); print(other...);
        }
    #endif
    template<typename T>
    fast_IO_t& operator >> (fast_IO_t &io, T &b) {
        return io.flag=read(b), io;
    }
    fast_IO_t& operator >> (fast_IO_t &io, char *b) {
        return io.flag=read(b), io;
    }
    template<typename T>
    fast_IO_t& operator << (fast_IO_t &io, T b) {
        return print(b), io;
    }
    #define cout io
    #define cin io
    #define endl '\n'
}
using namespace fast_IO;

// #define long long long
// #define int long
#define int128 __int128_t
#define int int128
#define long int128
const int MAXN = 2e5+10;
namespace NTT
{
    const int g = 3;
    const int p[4] = {0, 469762049, 998244353, 1004535809};
    const int inv[4] = {0, 156587350, 332748118, 334845270};
    int r[MAXN<<2];
    auto NTT(long *A, int limit, int mtp, int type) -> void;
    auto work(long *A, long *B, long *C, int len, int mtp) -> void;
}
int n, m, p, len;
long A[MAXN<<1], B[MAXN<<1];
long F[MAXN<<1], G[MAXN<<1];
long a[MAXN<<1], b[MAXN<<1], c[MAXN<<1];
long d[MAXN<<1], e[MAXN<<1], f[MAXN<<1];
auto ksm(long a, int b, const int p) -> long;
auto copyt(long *A, long *B, int len) -> void;
auto merge(long *A, long *B, long *C, int p1, int p2) -> void;
auto calck(long *A, long *B, long *C, int p1, int p2, int p3) -> void;
auto solve(long *A, long *B, long *C, int p1, int p2, int p ) -> void;

auto print(long *A, int len) -> void
{
    for(int i = 0; i <= len; i += 1) 
        print(A[i], ' ');
    print("\n\n");
}

auto main() -> signed
{
    // freopen("1.in", "r", stdin);
    // freopen("1.out", "w", stdout);
    read(n, m, p); len = n+m;
    for(int i = 0; i <= n; i += 1) read(F[i]);
    for(int i = 0; i <= m; i += 1) read(G[i]);
    copyt(A, F, n);copyt(B, G, m); NTT::work(A, B, a, len, 1);
    copyt(A, F, n);copyt(B, G, m); NTT::work(A, B, b, len, 2);
    copyt(A, F, n);copyt(B, G, m); NTT::work(A, B, c, len, 3);

    print(a, len);
    print(b, len);
    print(c, len);

    merge(a, b, d, NTT::p[1], NTT::p[2]);

    print(d, len);

    calck(c, d, e, NTT::p[1], NTT::p[2], NTT::p[3]);
    solve(d, e, f, NTT::p[1], NTT::p[2], p);
    for(int i = 0; i <= len; i += 1) print(f[i], ' ');
    // for(int i = 0; i <= len; i += 1) printf("%lld ", f[i]);
    return 0;
}
auto copyt(long *A, long *B, int len) -> void
{
    for(int i = 0; i <= 4e5; i += 1) A[i] = 0;
    for(int i = 0; i <= len; i += 1) A[i] = B[i];
}
auto merge(long *A, long *B, long *C, int p1, int p2) -> void
{
    const long mod = (long)p1*p2;
    long invp1 = ksm(p1, p2-2, p2);
    long invp2 = ksm(p2, p1-2, p1);
    for(int i = 0; i <= len; i += 1)
    C[i] = ((int128)(A[i])*p2%mod*invp2%mod+(int128)(B[i])*p1%mod*invp1%mod)%mod;
}
auto calck(long *A, long *B, long *C, int p1, int p2, int p3) -> void
{
    long invp1 = ksm(p1, p3-2, p3);
    long invp2 = ksm(p2, p3-2, p3);
    for(int i = 0; i <= len; i += 1)
    C[i] = (A[i]-B[i]+p3)%p3 *invp1%p3 *invp2%p3;
}
auto solve(long *A, long *B, long *C, int p1, int p2, int p ) -> void
{
    for(int i = 0; i <= len; i += 1)
    C[i] = (B[i]*p1%p*p2%p + A[i])%p;
}
auto ksm(long a, int b, const long p) -> long
{
    long ans = 1;
    for(; b; (a *= a) %= p, b >>= 1)
    if(b&1) (ans *= a) %= p; return ans;
}
namespace NTT
{
    auto NTT(long *A, int limit, int mtp, int type) -> void
    {
        long x, y, g0, gn;
        for(int i = 0; i < limit; i += 1)
            if(i < r[i]) swap(A[i], A[r[i]]);
        for(int midl = 1; midl < limit; midl <<= 1)
        {
            int len = midl<<1;
            g0 = ksm((type)?g:inv[mtp], (p[mtp]-1)/len, p[mtp]);
            for(int j = 0; j < limit; j += len)
            {
                gn = 1;
                for(int k = 0; k < midl; k += 1, (gn *= g0) %= p[mtp])
                {
                    x = A[j+k]%p[mtp];
                    y = gn*A[j+k+midl]%p[mtp];
                    A[j+k] = (x+y)%p[mtp];
                    A[j+k+midl] = (x-y+p[mtp])%p[mtp];
                }
            }
        }
    }
    auto work(long *A, long *B, long *C, int len, int mtp) -> void
    {
        int limit = 1, l = 0; long invlim;
        while(limit <= len) limit <<= 1, l += 1;
        for(int i = 1; i < limit; i += 1) 
            r[i] = (r[i>>1]>>1)|((i&1)<<(l-1));
        NTT(A, limit, mtp, 1); 
        NTT(B, limit, mtp, 1);
        for(int i = 0; i < limit; i += 1) 
            C[i] = A[i]*B[i]%p[mtp];
        NTT(C, limit, mtp, 0); 
        invlim = ksm(limit, p[mtp]-2, p[mtp]);
        for(int i = 0; i < limit; i += 1) 
            (C[i] *= invlim) %= p[mtp];
    }
}

by Rubidium_Chloride @ 2022-08-16 17:03:54

@Meteorshower_Y 学学三次变两次吧……


by Meteorshower_Y @ 2022-08-16 17:05:30

样例输入

10 32 2
316133033 283599618 204219867 370920642 252667717 163398849 63395880 181993402 450070248 746535125 606368805
573543188 358251194 62076044 440910171 702033347 873140273 489339950 341698431 663423219 897856975 225875896 878065331 749032224 232316733 246726391 196260400 652806755 15116632 476348590 976420144 276098450 618124309 32866383 799179899 266052150 341471514 89821467 167486323 872134567 792222039 924822776 912099917 618989538

样例输出

0 0 0 1 1 0 1 1 1 0 0 1 1 0 0 1 0 1 1 0 0 1 0 0 0 1 0 1 1 0 1 0 0 0 1 0 1 0 0 1 1 1 0 

by Meteorshower_Y @ 2022-08-16 17:10:38

@Reality_Creator de完挖的坑就去QwQ


by Meteorshower_Y @ 2022-08-16 18:03:44

过了...记得处处取模

磁铁终


|