震惊,居然卡精度

P4245 【模板】任意模数多项式乘法

Deep_Kevin @ 2018-09-22 14:00:03

这题用拆系数 FFT 是可以过的,但 Pi(以及复数运算)必须用 long double 的精度来计算,否则帅(毒)气(瘤)的出题人会卡你的精度。

#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
using namespace std;

// Degree of the first input polynomial; main reads n+1 coefficients for it.
int n;
// Degree of the second input polynomial; main reads m+1 coefficients for it.
int m;
// Modulus read from the input; every output coefficient is reduced modulo p.
long long p;
// Minimal complex number over long double, used by the FFT below.
//   x - real part
//   y - imaginary part
// Results are built with ordinary constructor calls (not GNU compound
// literals) so the type compiles as standard C++, and operands are taken by
// const reference to avoid copying two long doubles per call.
struct complex{
    long double x,y;
    // Implicit on purpose: a bare real value converts to (value, 0).
    complex (long double xx=0,long double yy=0):x(xx),y(yy) {}
    // Component-wise difference.
    complex operator-(const complex &b)const {return complex(x-b.x,y-b.y);}
    // Component-wise sum.
    complex operator+(const complex &b)const {return complex(x+b.x,y+b.y);}
    // Complex product: (x + yi)(b.x + b.y i).
    complex operator*(const complex &b)const {return complex(x*b.x-y*b.y,x*b.y+y*b.x);}
}
// a[0]/a[1]: low (x%32768) and high (x/32768) parts of the first polynomial's
// coefficients; b[0]/b[1]: the same split for the second polynomial.
// opx/opy: scratch buffers that main transforms and multiplies.
a[2][300010],b[2][300010],opx[300010],opy[300010];
// Coefficients of the product polynomial, accumulated modulo p by main.
long long ans[300010];
// Bit-reversal permutation table: dft swaps now[i] with now[r[i]].
int r[300010];
// limit: transform length, doubled by main until it reaches n+m+1.
// l: number of doublings applied to limit (so limit == 1<<l).
int limit=1,l=0;
// The argument must be a long double literal: acos(-1.0) selects the double
// overload and would leave Pi accurate only to double precision, which is not
// enough for the twiddle factors computed from it.
const long double Pi=std::acos(-1.0L);
// Twiddle-factor table: main stores cos(2*Pi/k) + i*sin(2*Pi/k) in wn[k], and
// dft reads wn[mid] for each butterfly span mid.
complex wn[300010];

// In-place iterative radix-2 transform over the first `limit` entries of now.
//   now  - buffer to transform; must hold at least `limit` elements
//   idft - 1 for the forward transform, -1 for the inverse direction
// The inverse direction only conjugates the twiddle factors; it does not
// divide by `limit`, so the caller has to apply that scaling itself.
void dft(complex *now,int idft){
    // Reorder into bit-reversed order; the i<r[i] test swaps each pair once.
    for(int i=0;i<limit;++i){
        const int j=r[i];
        if(i>=j) continue;
        complex tmp=now[i];
        now[i]=now[j];
        now[j]=tmp;
    }
    // Merge blocks of size len/2 into blocks of size len.
    for(int len=2;len<=limit;len*=2){
        const int half=len/2;
        complex step=wn[len];
        step.y*=idft;  // conjugate the root when idft == -1
        for(int base=0;base<limit;base+=len){
            complex w(1,0);  // running power of step, advanced by multiplication
            for(int k=0;k<half;++k){
                complex *lo=now+base+k;
                complex *hi=lo+half;
                const complex even=*lo;
                const complex odd=w*(*hi);
                *lo=even+odd;
                *hi=even-odd;
                w=w*step;
            }
        }
    }
}

int main(){
    scanf("%d %d %lld",&n,&m,&p);
    while(limit<n+m+1) limit*=2,l++;
    for(int i=0;i<limit;i++) r[i]=(r[i>>1]>>1) | ((i&1)<<(l-1));
    for(int i=2;i<=limit;i++) wn[i]=(complex){std::cos(2.0*Pi/i),std::sin(2.0*Pi/i)};
    long long x;
    for(int i=0;i<=n;i++){
        scanf("%lld",&x);
        a[0][i].x=x%32768;
        a[1][i].x=x/32768;
    }
    for(int i=0;i<=m;i++){
        scanf("%lld",&x);
        b[0][i].x=x%32768;
        b[1][i].x=x/32768;
    }
    long long temp=(long long )1073741824%p;
    for(int i=0;i<=1;i++)
        for(int j=0;j<=1;j++){
            for(int k=0;k<=limit;k++) opx[k].x=opx[k].y=opy[k].x=opy[k].y=0;
            for(int k=0;k<=n;k++) opx[k]=a[i][k];
            for(int k=0;k<=m;k++) opy[k]=b[j][k];
            dft(opx,1); 
            dft(opy,1);
            for(int k=0;k<=limit;k++) opx[k]=opx[k]*opy[k];
            dft(opx,-1);
            for(int k=0;k<=limit;k++) opx[k].x=((long long)(opx[k].x/limit+0.5))%p;
            if(i+j==0) for(int k=0;k<=n+m;k++) ans[k]=(ans[k]+(long long)opx[k].x)%p;
            else if(i+j==1) for(int k=0;k<=n+m;k++) ans[k]=(ans[k]+(long long)opx[k].x*32768%p)%p;
            else if(i+j==2) for(int k=0;k<=n+m;k++) ans[k]=(ans[k]+(long long)opx[k].x*temp%p)%p;
        }
    for(int i=0;i<=n+m;i++) printf("%lld ",ans[i]);
}

|