救救蒟蒻的拆系数FFT吧

P4245 【模板】任意模数多项式乘法

ButterflyDew @ 2018-12-17 22:20:24

我最开始以为是精度爆了

于是换了long double,预处理单位根,用stl的三角函数

然后还是挂

把数组开大,还是挂

怀疑是不是自己哪里写错掉了...

#include <cstdio>
#include <algorithm>
#include <cmath>
#include <vector>
#define ll long long
const int N=(1<<20)+10;

// Complex number with long double components, used as the FFT element type.
// Only +, - and * are needed by the transform.
struct complex
{
    long double x,y; // real part, imaginary part
    complex(){}
    complex(long double x,long double y):x(x),y(y){}
    friend complex operator +(complex n1,complex n2){return complex(n1.x+n2.x,n1.y+n2.y);}
    friend complex operator -(complex n1,complex n2){return complex(n1.x-n2.x,n1.y-n2.y);}
    friend complex operator *(complex n1,complex n2){return complex(n1.x*n2.x-n1.y*n2.y,n1.x*n2.y+n1.y*n2.x);}
}A[N],B[N]; // FFT work buffers

// a, b: input coefficients; n, m: coefficient counts;
// turn: bit-reversal permutation; len: FFT length (power of two); L = log2(len)-1.
int a[N],b[N],n,turn[N],m,len=1,L=-1;
ll ans[N],mod; // ans: product coefficients modulo mod
const ll M=32768; // 2^15: each value x is split as (x/M)*M + x%M
// Use a long double argument: std::acos(-1) (int argument) selects the double
// overload, so pi would only carry double precision despite being long double.
const long double pi=std::acos(-1.0L);
// Per-level twiddle-factor tables (filled in main, read by FFT).
std::vector <long double> Sin[20];
std::vector <long double> Cos[20];
// In-place iterative radix-2 FFT over a[0..len).
// typ = 1 gives the forward transform and typ = -1 the inverse one. The
// inverse is NOT divided by len; callers do that when reading results.
// For stage `lvl` (butterfly half-width half = 2^(lvl-1)), the twiddle for
// offset k is (Cos[lvl][k], typ*Sin[lvl][k]).
void FFT(complex *a,int typ)
{
    // Reorder the input by the bit-reversal permutation stored in turn[].
    for(int i=1;i<len;i++)
        if(i<turn[i]) std::swap(a[i],a[turn[i]]);

    int lvl=1;
    for(int half=1;half<len;half*=2)
    {
        const int span=half*2;
        for(int base=0;base<len;base+=span)
            for(int k=0;k<half;k++)
            {
                complex &lo=a[base+k],&hi=a[base+k+half];
                const complex t=complex(Cos[lvl][k],typ*Sin[lvl][k])*hi;
                hi=lo-t;
                lo=lo+t;
            }
        ++lvl;
    }
}
// Cyclic convolution of a and b (length len) via FFT.
// The result overwrites a and is still scaled by len; callers divide by len.
// b is left holding its forward transform.
void polymul(complex *a,complex *b)
{
    FFT(a,1);
    FFT(b,1);
    for(int k=0;k<len;++k)
        a[k]=a[k]*b[k];
    FFT(a,-1);
}
void MTT(int *a,int *b)
{
    while(len<=n+m) len<<=1,++L;
    for(int i=0;i<len;i++) turn[i]=turn[i>>1]>>1|(i&1)<<L;

    for(int i=0;i<len;i++) A[i]=complex(0,0),B[i]=complex(0,0);
    for(int i=0;i<n;i++) A[i]=complex(a[i]/M,0);
    for(int i=0;i<m;i++) B[i]=complex(b[i]/M,0);
    polymul(A,B);
    for(int i=0;i<len;i++) (ans[i]+=M*M%mod*(ll)(A[i].x/len+0.5))%=mod;

    for(int i=0;i<len;i++) A[i]=complex(0,0),B[i]=complex(0,0);
    for(int i=0;i<n;i++) A[i]=complex(a[i]/M,0);
    for(int i=0;i<m;i++) B[i]=complex(b[i]%M,0);
    polymul(A,B);
    for(int i=0;i<len;i++) (ans[i]+=M*(ll)(A[i].x/len+0.5))%=mod;

    for(int i=0;i<len;i++) A[i]=complex(0,0),B[i]=complex(0,0);
    for(int i=0;i<n;i++) A[i]=complex(a[i]%M,0);
    for(int i=0;i<m;i++) B[i]=complex(b[i]/M,0);
    polymul(A,B);
    for(int i=0;i<len;i++) (ans[i]+=M*(ll)(A[i].x/len+0.5))%=mod;

    for(int i=0;i<len;i++) A[i]=complex(0,0),B[i]=complex(0,0);
    for(int i=0;i<n;i++) A[i]=complex(a[i]%M,0);
    for(int i=0;i<m;i++) B[i]=complex(b[i]%M,0);
    polymul(A,B);
    for(int i=0;i<len;i++) (ans[i]+=(ll)(A[i].x/len+0.5))%=mod;
}
int main()
{
    // n and m are read as degrees; n+1 and m+1 coefficients follow.
    scanf("%d%d%lld",&n,&m,&mod);
    ++n,++m; // degrees -> coefficient counts

    // Twiddle tables: Cos[l][k], Sin[l][k] = cos, sin(2*pi*k / 2^l).
    // FFT stage l only reads k < 2^(l-1), so that many entries are stored.
    // Every level the table arrays can hold is filled, so transforms of any
    // length up to 2^max_level find a non-empty table.
    const int max_level=(int)(sizeof(Sin)/sizeof(Sin[0]))-1;
    long double lw=2.0; // 2^le
    for(int le=1;le<=max_level;le++,lw=lw*2.0)
        for(int i=0;i<1<<(le-1);i++)
            Sin[le].push_back(std::sin(pi*2/lw*i)),Cos[le].push_back(std::cos(pi*2/lw*i));

    for(int i=0;i<n;i++) scanf("%d",a+i);
    for(int i=0;i<m;i++) scanf("%d",b+i);
    MTT(a,b);
    // The product has n+m-1 coefficients.
    for(int i=0;i<n+m-1;i++) printf("%lld ",(ans[i]+mod)%mod);
    return 0;
}

by Jameswood @ 2018-12-17 23:32:42

默哀三秒

fft是我一道迈不过去的坎


by ButterflyDew @ 2018-12-18 08:37:09

把代码改了一下

#include <cstdio>
#include <algorithm>
#include <cmath>
#include <vector>
#define ll long long
const int N=(1<<20)+10;

// Complex number with long double components, used as the FFT element type.
// Only +, - and * are needed by the transform.
struct complex
{
    long double x,y; // real part, imaginary part
    complex(){}
    complex(long double x,long double y):x(x),y(y){}
    friend complex operator +(complex n1,complex n2){return complex(n1.x+n2.x,n1.y+n2.y);}
    friend complex operator -(complex n1,complex n2){return complex(n1.x-n2.x,n1.y-n2.y);}
    friend complex operator *(complex n1,complex n2){return complex(n1.x*n2.x-n1.y*n2.y,n1.x*n2.y+n1.y*n2.x);}
}A[N],B[N],C[N],D[N],E[N],F[N],G[N],H[N];
// A, C: high/low halves of a; B, D: high/low halves of b; E..H: the four
// point-wise products (see MTT).
// NOTE(review): 8 arrays of N elements of sizeof(complex) bytes each; with a
// 16-byte long double that is roughly 256 MB of static storage. Check it
// against the memory limit.

// a, b: input coefficients; n, m: coefficient counts;
// turn: bit-reversal permutation; len: FFT length (power of two); L = log2(len)-1.
int a[N],b[N],n,turn[N],m,len=1,L=-1;
ll ans[N],mod; // ans: product coefficients modulo mod
const ll M=32768; // 2^15: each value x is split as (x/M)*M + x%M
// Use a long double argument: std::acos(-1) (int argument) selects the double
// overload, so pi would only carry double precision despite being long double.
const long double pi=std::acos(-1.0L);
// Per-level twiddle-factor tables (filled in main, read by FFT).
std::vector <long double> Sin[20];
std::vector <long double> Cos[20];
// In-place iterative radix-2 FFT over a[0..len).
// typ = 1 gives the forward transform and typ = -1 the inverse one. The
// inverse is NOT divided by len; callers do that when reading results.
// For stage `lvl` (butterfly half-width half = 2^(lvl-1)), the twiddle for
// offset k is (Cos[lvl][k], typ*Sin[lvl][k]).
void FFT(complex *a,int typ)
{
    // Reorder the input by the bit-reversal permutation stored in turn[].
    for(int i=1;i<len;i++)
        if(i<turn[i]) std::swap(a[i],a[turn[i]]);

    int lvl=1;
    for(int half=1;half<len;half*=2)
    {
        const int span=half*2;
        for(int base=0;base<len;base+=span)
            for(int k=0;k<half;k++)
            {
                complex &lo=a[base+k],&hi=a[base+k+half];
                const complex t=complex(Cos[lvl][k],typ*Sin[lvl][k])*hi;
                hi=lo-t;
                lo=lo+t;
            }
        ++lvl;
    }
}
void MTT(int *a,int *b)
{
    while(len<=n+m) len<<=1,++L;
    for(int i=0;i<len;i++) turn[i]=turn[i>>1]>>1|(i&1)<<L;

    for(int i=0;i<n;i++) A[i].x=a[i]/M,C[i].x=a[i]%M;
    for(int i=0;i<m;i++) B[i].x=b[i]/M,D[i].x=b[i]%M;
    FFT(A,1),FFT(B,1),FFT(C,1),FFT(D,1);
    for(int i=0;i<len;i++) E[i]=A[i]*B[i],F[i]=A[i]*D[i],G[i]=B[i]*C[i],H[i]=C[i]*D[i];
    FFT(E,-1),FFT(F,-1),FFT(G,-1),FFT(H,-1);
    for(int i=0;i<len;i++)
        ans[i]=(M*M%mod*(ll)(E[i].x/len+0.1)%mod+M*(ll)(F[i].x/len+0.1)%mod+M*(ll)(G[i].x/len+0.1)%mod+(ll)(H[i].x/len+0.1))%mod;
}
int main()
{
    // n and m are read as degrees; n+1 and m+1 coefficients follow.
    scanf("%d%d%lld",&n,&m,&mod);
    ++n,++m; // degrees -> coefficient counts

    // Twiddle tables: Cos[l][k], Sin[l][k] = cos, sin(2*pi*k / 2^l).
    // FFT stage l only reads k < 2^(l-1), so that many entries are stored.
    // Every level the table arrays can hold is filled, so transforms of any
    // length up to 2^max_level find a non-empty table.
    const int max_level=(int)(sizeof(Sin)/sizeof(Sin[0]))-1;
    long double lw=2.0; // 2^le
    for(int le=1;le<=max_level;le++,lw=lw*2.0)
        for(int i=0;i<1<<(le-1);i++)
            Sin[le].push_back(std::sin(pi*2/lw*i)),Cos[le].push_back(std::cos(pi*2/lw*i));

    for(int i=0;i<n;i++) scanf("%d",a+i);
    for(int i=0;i<m;i++) scanf("%d",b+i);
    MTT(a,b);
    // The product has n+m-1 coefficients.
    for(int i=0;i<n+m-1;i++) printf("%lld ",(ans[i]+mod)%mod);
    return 0;
}

by 小菜鸟 @ 2018-12-19 21:40:30

四舍五入不是+0.5吗,大佬怎么+0.1?


by 小菜鸟 @ 2018-12-19 21:43:39

还有,为了减小对精度的需求,M 最好取 $\sqrt{mod}$


by 小菜鸟 @ 2018-12-19 21:44:46

另外,推荐三模数NTT,不用管精度


by ButterflyDew @ 2019-01-15 18:49:26

上面代码不是一份0.5一份0.1嘛


by shadowice1984 @ 2019-01-15 18:51:13

@ButterflyDew

说实在的,感觉很多人的fft循环一点也不优雅,肥肠的丑


|