yhx-12243 的 NTT 究竟写了些什么(详细揭秘)

moongazer

2021-04-02 16:36:37

Personal

这是 yhx-12243 的 NTT

inline int & reduce(int &x) {return x += x >> 31 & mod;}
inline int & neg(int &x) {return x = (!x - 1) & (mod - x);}
u64 PowerMod(u64 a, int n, u64 c = 1) {for (; n; n >>= 1, a = a * a % mod) if (n & 1) c = c * a % mod; return c;}
namespace poly_base {
    int l, n; u64 iv; vec w2;
    void init(int n = N, bool dont_calc_factorials = true) {
        int i, t;
        for (inv[1] = 1, i = 2; i < n; ++i) inv[i] = u64(mod - mod / i) * inv[mod % i] % mod;
        if (!dont_calc_factorials) for (*finv = *fact = i = 1; i < n; ++i) fact[i] = (u64)fact[i - 1] * i % mod, finv[i] = (u64)finv[i - 1] * inv[i] % mod;
        t = min(n > 1 ? lg2(n - 1) : 0, 21),
        *w2 = 1, w2[1 << t] = PowerMod(31, 1 << (21 - t));
        for (i = t; i; --i) w2[1 << (i - 1)] = (u64)w2[1 << i] * w2[1 << i] % mod;
        for (i = 1; i < n; ++i) w2[i] = (u64)w2[i & (i - 1)] * w2[i & -i] % mod;
    }
    inline void NTT_init(int len) {n = 1 << (l = len), iv = mod - (mod - 1) / n;}
    void DIF(int *a) {
        int i, *j, *k, len = n >> 1, R, *o;
        for (i = 0; i < l; ++i, len >>= 1)
            for (j = a, o = w2; j != a + n; j += len << 1, ++o)
                for (k = j; k != j + len; ++k)
                    R = (u64)*o * k[len] % mod, reduce(k[len] = *k - R), reduce(*k += R - mod);
    }
    void DIT(int *a) {
        int i, *j, *k, len = 1, R, *o;
        for (i = 0; i < l; ++i, len <<= 1)
            for (j = a, o = w2; j != a + n; j += len << 1, ++o)
                for (k = j; k != j + len; ++k)
                    reduce(R = *k + k[len] - mod), k[len] = u64(*k - k[len] + mod) * *o % mod, *k = R;
    }
    inline void DNTT(int *a) {DIF(a);}
    inline void IDNTT(int *a) {
        DIT(a), std::reverse(a + 1, a + n);
        for (int i = 0; i < n; ++i) a[i] = a[i] * iv % mod;
    }
}

它为什么跑这么快?DIT 和 DIF 在干啥?预处理的原根为何和大多数人的不一样?这篇文章将为你解开这一奥秘(

先来看 init 函数 w2[1 << t] = PowerMod(31, 1 << (21 - t)); 为什么是 31

我们发现 31^{2^{23}}=1 同时它模 998244353 的阶是 2^{23} 的倍数,也就是说它在进行 NTT 时和 3^{119} 具有相似的性质,事实上,这里的确可以换为 3^{119}

平时我的写法都要预处理 21 种原根的次幂,为什么这里只用处理一种原根呢?我们将 31 改为 3^{119} 输出一下这段代码预处理的原根前 8 项,发现结果如下:

1 911660635 372528824 488723995 929031873 373294451 628914303 661054123

再来看平常写法预处理的原根:

1: 1
2: 1 911660635
4: 1 372528824 911660635 488723995
8: 1 929031873 372528824 628914303 911660635 373294451 488723995 661054123

我们发现对这一结果蝴蝶变换(二进制翻转)可以得到如下结果:

1: 1
2: 1 911660635
4: 1 911660635 372528824 488723995
8: 1 911660635 372528824 488723995 929031873 373294451 628914303 661054123

我们发现 12 的前缀,24 的前缀……

经过冷静思考,我们发现这是显然的,蝴蝶变换是 0 不动,偶数放左边,奇数放右边,分别进行少一位的蝴蝶变换,而根据 \omega_{2n}^{2i}=\omega_n^i 所以它前一半就是对 \frac{n}{2} 范围的原根做蝴蝶变换的结果。

代码在做什么也很好懂了,预处理出 g^{2^k} 放在 2^{21-k} 处(即蝴蝶变换后的结果),再递推得到其他结果(g^{2^j+2^k}=g^{2^j}\times g^{2^k},二进制翻转后也可以这样找每个为 1 的位乘上)。

这样预处理原根有什么用?等下就知道了。

我们还要知道它的基本原理:DIT/DIF。在 rushcheyo 学长《转置原理及其应用》中我们了解到 DIT(decimation in time,按时域抽取)-FFT 可以将蝴蝶变换后的系数向量转化为点值向量; DIF(decimation in frequency,按频域抽取)-FFT 可以将系数向量转化为蝴蝶变换后的点值向量,二者互为置换。

Update on 2023.01.13: 更新了一点内容,请在这篇文章查看。

我们发现可以用 DIF 实现 DFT,用 DIT 实现 IDFT 于是我们就不用进行蝴蝶变换了。

这是我写的一份朴素的 DIT/DIF-NTT:

void init_Poly() {
  for (int l = 1; l < (1 << 21); l <<= 1) {
    gw[l] = 1;
    int gn = pow(g, (Mod - 1) / (l << 1), Mod);
    for (int j = 1; j < l; ++j) {
      gw[l | j] = 1ll * gw[l | (j - 1)] * gn % Mod;
    }
  }
}
void DIT(int *A, int lim, bool flag) {
  for (int l = 1; l < lim; l <<= 1) {
    int *k = A;
    for (int i = 0; i < lim; i += (l << 1), k += (l << 1)) {
      int *x = k;
      for (int j = 0, *g = gw + l; j < l; ++j, ++x, ++g) {
        int o = 1ll * x[l] * *g % Mod;
        x[l] = (*x + Mod - o) % Mod, *x = (*x + o) % Mod;
      }
    }
  }
  int iv = pow(lim, Mod - 2, Mod);
  for (int i = 0; i < lim; ++i) A[i] = 1ll * A[i] * iv % Mod;
  std::reverse(A + 1, A + lim);
}
void DIF(int *A, int lim, bool flag) {
  for (int l = lim / 2; l >= 1; l >>= 1) {
    int *k = A;
    for (int i = 0; i < lim; i += (l << 1), k += (l << 1)) {
      int *x = k;
      for (int j = 0, *g = gw + l; j < l; ++j, ++x, ++g) {
        int o = x[l];
        x[l] = 1ll * (*x + Mod - o) * *g % Mod, *x = (*x + o) % Mod;
      }
    }
  }
}

这里的原根是最朴素的处理方式,而在进行 DIT/DIF 的时候,我们需要移动 \operatorname{O}(n\log n) 次原根,而 yhx-12243 的 DIT/DIF 只需要移动 \operatorname{O}(n) 次。

我们还发现一件神奇的事:yhx-12243 的 DIT 除了最外层 len 的枚举顺序,似乎都在做 DIF,而 DIF 除了最外层 len 的枚举顺序,似乎都在做 DIT!

这是一张 DIT-FFT 和 DIF-FFT 的示意图:

我们观察到 DIT-FFT 时如果对系数向量进行了蝴蝶变换,对 (0,4) 操作变为了对 (0,1) 操作,对 (4,6) 操作变为了对 (1,3) 操作,如果不对系数向量做蝴蝶变换并保持原先的操作呢(即仍然是对 (0,4) 操作,对 (4,6) 操作)?好像这样仍然会得到一个点值数组,这个点值数组正是蝴蝶变换后的点值数组!

原因是简单的:观察到蝴蝶变换的置换 A 有:A^{-1}=A 对于输入的系数数组做这一置换,运算过程不变,那么答案也应当也被做了该置换,于是 A\circ A=I(输入),I\circ A=A(答案)。

而原先要找的原根,也要对应的蝴蝶变换一下,这时候预处理蝴蝶变换后的原根的作用就体现出来了!

更为重要的是,对于一个 len 覆盖到的范围,所用的原根次幂是相同的(例如第一层变换中的 (0,4),(1,5),(2,6),(3,7),第二层变换中的 (0,2),(1,3)(4,6),(5,7)

以上内容可以手画一下长为 16 的 DIT-FFT 来加深理解。

于是按从大到小枚举 len 的顺序做 DIT,干的就是 DIF 的事,同理我们也可以得到按从小到大枚举 len 的顺序做 DIF,干的就是 DIT 的事,而这种做法因为只需要移动 T(n+\frac{n}{2}+\frac{n}{4}+\cdots)=\operatorname{O}(n) 次原根所以会比原先快一些。

下面进行一些可能并不靠谱的效率差异比较(以下三份代码都使用 unsigned long long 优化,即用 ull 存储中间结果减少取模):

  1. 朴素 FFT 279.439 ms,代码 2.43 KB
  2. DIT-DIF FFT 212.99 ms,代码 2.93 KB
  3. 优化 DIT-DIF FFT 192.85 ms,代码 2.94 KB

可见 DIT-DIF FFT 相较于朴素 FFT 相比,有较大优化,而优化 DIT-DIF FFT 相较于 DIT-DIF FFT 有小幅度优化,且代码不长,实现难度不大,不失为一种较好的简单 NTT 实现方式。