两份差不多的代码为什么一份会被卡精度,一份能过

P4245 【模板】任意模数多项式乘法

chenkuowen01 @ 2019-01-15 20:21:52

AC代码:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ld;
const ld pi=acos(-1);
const int MAX_N=1<<19|5;
struct C{
    ld x,y;
    inline C operator+(C b){ return (C){x+b.x,y+b.y}; }
    inline C operator-(C b){ return (C){x-b.x,y-b.y}; }
    inline C operator*(C b){ return (C){x*b.x-y*b.y,x*b.y+y*b.x}; }
    inline C operator!(){ return (C){x,-y}; }
};
inline void FFT(C a[],int n,int t){
    for(int i=0,pos=0;i<n;++i){
        if(i<pos) swap(a[i],a[pos]);
        for(int p=n>>1;(pos^=p)<p;p>>=1);
    }
    static C p[MAX_N]; p[0]=(C){1,0};
    for(int step=1;step<n;step<<=1){
        C w=(C){cos(pi/step),t*sin(pi/step)};
        for(int i=1;i<step;++i) p[i]=p[i-1]*w;
        for(int i=0;i<n;i+=step<<1)
            for(int j=i;j<i+step;++j){
                C x=a[j],y=a[j+step]*p[j-i];
                a[j]=x+y; a[j+step]=x-y;
            }
    }
    if(t==-1) for(int i=0;i<n;++i) a[i].x/=n,a[i].y/=n;
}
C a[MAX_N],b[MAX_N],c[MAX_N],d[MAX_N],p[MAX_N],q[MAX_N];
int getint(){
    int ret=0;
    char c=getchar();
    for(;c<'0'||c>'9';c=getchar());
    for(;c>='0'&&c<='9';c=getchar())
        ret=ret*10+c-'0';
    return ret;
}
int main(){
    int n,m,mod; scanf("%d%d%d",&n,&m,&mod);
    int m0=1<<15,top;
    for(top=1;top<=2*max(n,m);top<<=1);
    for(int i=0;i<=n;++i){
        int key=getint();
        p[i]=(C){(ld)(key/m0),(ld)(key%m0)};
    }
    for(int i=0;i<=m;++i){
        int key=getint();
        q[i]=(C){(ld)(key/m0),(ld)(key%m0)};
    }
    FFT(p,top,1); FFT(q,top,1);
    for(int i=0;i<top;++i){
        int j=i?top-i:0;
        b[i]=(p[i]-!p[j])*(C){0,-0.5}*(q[i]-!q[j])*(C){0,-0.5};
        a[i]=p[i]*q[i]+b[i];
    }
    FFT(a,top,-1); FFT(b,top,-1);
    for(int i=0;i<=n+m;++i){
        ll a1=(ll)(a[i].x+0.5)%mod;
        ll b1=(ll)(a[i].y+0.5)%mod;
        ll c1=(ll)(b[i].x+0.5)%mod;
        printf("%lld ",((a1*m0%mod*m0%mod+b1*m0%mod+c1)%mod+mod)%mod);
    }
    return 0;
}

WA掉的代码(60分):

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ld;
const ld pi=acos(-1);
const int MAX_N=1<<19|5;
struct C{
    ld x,y;
    inline C operator+(C b){ return (C){x+b.x,y+b.y}; }
    inline C operator-(C b){ return (C){x-b.x,y-b.y}; }
    inline C operator*(C b){ return (C){x*b.x-y*b.y,x*b.y+y*b.x}; }
    inline C operator!(){ return (C){x,-y}; }
};
inline void FFT(C a[],int n,int t){
    for(int i=0,pos=0;i<n;++i){
        if(i<pos) swap(a[i],a[pos]);
        for(int p=n>>1;(pos^=p)<p;p>>=1);
    }
    for(int step=1;step<n;step<<=1)
        for(int i=0;i<n;i+=step<<1)
            for(int j=i;j<i+step;++j){
                C w=(C){cos(pi*(j-i)/step),t*sin(pi*(j-i)/step)};
                C x=a[j],y=a[j+step]*w;
                a[j]=x+y; a[j+step]=x-y;
            }
    if(t==-1) for(int i=0;i<n;++i) a[i].x/=n,a[i].y/=n;
}
C a[MAX_N],b[MAX_N],c[MAX_N],d[MAX_N],p[MAX_N],q[MAX_N];
int getint(){
    int ret=0;
    char c=getchar();
    for(;c<'0'||c>'9';c=getchar());
    for(;c>='0'&&c<='9';c=getchar())
        ret=ret*10+c-'0';
    return ret;
}
int main(){
    int n,m,mod; scanf("%d%d%d",&n,&m,&mod);
    int m0=sqrt(mod)+1,top=1<<18;
    for(int i=0;i<=n;++i){
        int key=getint();
        p[i]=(C){(ld)(key/m0),(ld)(key%m0)};
    }
    for(int i=0;i<=m;++i){
        int key=getint();
        q[i]=(C){(ld)(key/m0),(ld)(key%m0)};
    }
    FFT(p,top,1); FFT(q,top,1);
    for(int i=0;i<top;++i){
        int j=i?top-i:0;
        C a1,b1,c1,d1;
        b[i]=(p[i]-!p[j])*(C){0,-0.5}*(q[i]-!q[j])*(C){0,-0.5};
        a[i]=p[i]*q[i]+b[i];
    }
    FFT(a,top,-1); FFT(b,top,-1);
    for(int i=0;i<=n+m;++i){
        ll a1=(ll)(a[i].x+0.5)%mod;
        ll b1=(ll)(a[i].y+0.5)%mod;
        ll c1=(ll)(b[i].x+0.5)%mod;
        printf("%lld ",((a1*m0%mod*m0%mod+b1*m0%mod+c1)%mod+mod)%mod);
    }
    return 0;
}

感觉除了m0的大小貌似没有什么其它区别了,为什么分数相差会这么大


by totorato @ 2019-03-01 18:52:47

的确存在这样的问题。我也不知道为什么,可能跟浮点数的存储方式有关


by AN94 @ 2019-07-26 21:28:11

M_0 不应该根据 mod 来设置,

而是根据系数的最大可能范围和多项式长度来设置的。 本题,系数 ai <= 1e9,多项式长度 1e5, 那么卷积之后结果(称result), result[i] <= 1e9*1e9*1e5 = 1e23double以及long double精度存不下, 那么拆系数后

A(x)=K_A(x)M_0+R_A(x) B(x)=K_B(x)M_0+R_B(x)

其中

K_A[j]=\lfloor\frac{A[j]}{M}\rfloor R_A[j]=A[j] \% M

也就是 A[j]=K_A[j] × M_0+R_A[j]

$$A(x)B(x) = M_0^2K_A(x)K_B(x)+M_0[(K_A(x)R_B(x) + K_B(x)R_A(x)] + R_A(x)R_B(x)$$ 为了不让FFT在计算这4个分解出来的多项式乘积时爆掉,长度没变,所以尽可能降低它们系数范围, 由于 `A[j]、B[j]` 范围是`1e9`,所以 $M_0$取1e4~1e5之间的数比较好(尽量接近 sqrt(1e9) 约为31623)。而 `1<<15 = 32768`符合要求。

by EMT__Mashiro @ 2019-12-27 14:25:28

读入的时候直接系数对p取模。
这样的话你的M_0设为\sqrt p是最优的,和长度没半毛钱关系。
当然取2^{15}加快取模也行。


|