矩阵乘法的 Strassen 算法

Hexarhy

2021-06-13 18:12:26

Algo. & Theory

Update

  1. [2021/11/14] 感谢 @serverkiller 指出错误并修正。
  2. [2021/11/14] 替换了扩展阅读第二个 link。

Introduction

Strassen 算法在 OI 上没有任何应用价值,不过了解一下理论计算机科学相关还蛮有意思的。

计算两个矩阵相乘的朴素方法是 \Theta(n^3)。而 Strassen 算法通过采用分治策略,并减少递归次数,实现在 \Theta(n^{\log 7}) 时间内完成。\log 7\approx 2.808,也就是 O(n^{2.81})

前置知识: 矩阵基础,主定理。

其实主定理只需要知道结论即可,这里放一下简化结论:

T(n)=aT\left(\dfrac{n}{b}\right)+O(n^d),a\ge 1,b>1\\ T(n)=\begin{cases} O(n^d)&,d>\log_{b}a\\ O(n^d\log n)&,d=\log_b a\\ O(n^{\log_b a})&,d<\log_b a \end{cases}

Analysis

从分治开始

在介绍 Strassen 算法之前,先探讨如何用分治完成矩阵乘法。

分治策略其实非常直截了当,就是将 n\times n 的矩阵划分为四个 \dfrac{n}{2}\times \dfrac{n}{2} 的矩阵。当然,这种做法要求 n2 的幂,但相关细节我们稍后探讨。

先以 2\times 2 的矩阵为例。

\begin{bmatrix}A_{1,1}&A_{1,2}\\A_{2,1}&A_{2,2}\end{bmatrix}\cdot\begin{bmatrix}B_{1,1}&B_{1,2}\\B_{2,1}&B_{2,2}\end{bmatrix}=\begin{bmatrix}C_{1,1}&C_{1,2}\\C_{2,1}&C_{2,2}\end{bmatrix}

具体写出计算矩阵 C 的等式:

C_{1,1}=A_{1,1}\cdot B_{1,1}+A_{1,2}\cdot B_{2,1}\\ C_{1,2}=A_{1,1}\cdot B_{1,2}+A_{1,2}\cdot B_{2,2}\\ C_{2,1}=A_{2,1}\cdot B_{1,1}+A_{2,2}\cdot B_{2,1}\\ C_{2,2}=A_{2,1}\cdot B_{1,2}+A_{2,2}\cdot B_{2,2}\\

对于每一条公式,相当于计算两对 \dfrac{n}{2}\times\dfrac{n}{2} 矩阵乘法,再计算一次这样的矩阵加法。

直接利用这个运算方式就可以写出递归分治策略。形象化地说,就是将矩阵分解为左上,左下,右上,右下四个子矩阵再分别进行运算。

\boxed{ \begin{array}{ll} &\textbf {Function Multiply}(A,B)\\ 1& n\gets A.\text{rows}\\ 2& \text{let}\ C\ \text{be a new } n\times n\ \text{matrix}\\ 3& \textbf{if } n=1\\ 4& \qquad c_{1,1}\gets a_{1,1}\times b_{1,1}\\ 5&\textbf{else}\ \mathrm{Partion}\ A,B,C\\ 6&\qquad C_{1,1}\gets\text {Multiply}(A_{1,1},B_{1,1})+\text {Multiply}(A_{1,2},B_{2,1})\\ 7&\qquad C_{1,2}\gets\text {Multiply}(A_{1,1},B_{1,2})+\text {Multiply}(A_{1,2},B_{2,2})\\ 8&\qquad C_{2,1}\gets\text {Multiply}(A_{2,1},B_{1,1})+\text {Multiply}(A_{2,2},B_{2,1})\\ 9&\qquad C_{2,2}\gets\text {Multiply}(A_{2,1},B_{1,2})+\text {Multiply}(A_{2,2},B_{2,2})\\ 10&\textbf{return}\ C \end{array}}

注意 \rm Partion 部分只需要对应好分解后的矩阵的下标即可。具体的对应方法可以参考 Strassen 算法伪代码。

需要说明的是,《算法导论》认为,只需通过下标计算即可对实现分解子矩阵并操作,而不用 \Theta(n^2) 拷贝子矩阵。然而在亲自动手实现代码时,避免拷贝子矩阵来进行其他操作是异常麻烦的,书中也没有给出伪代码。况且拷贝也不影响总的时间复杂度,因为矩阵加法需要不可避免的 \Theta(n^2)。但拷贝操作对常数因子影响比较大。如果有谁会实现不用拷贝的请务必把代码发给我/kk

现在来看一下这段代码的时间复杂度。

我们分解出了 8 个子问题,每个子问题规模缩小了一半,同时花了 \Theta(n^2) 时间分解出子矩阵,花了 \Theta(n^2) 进行矩阵加法。容易写出其递归式:

\begin{aligned} T(n)&=8T\left(\dfrac{n}{2}\right)+\Theta(n^2) \end{aligned}

运用主定理即可求解时间复杂度为 \Theta(n^3)

没有优化啊?这就来到 Strassen 算法的另一个核心:减少递归次数。

还能再少一次

Strassen 算法在朴素分治算法的基础上,只进行了 7 次递归。当然减少递归次数的代价就是多进行了几次矩阵加法,但幸好只是常数级别。

我们先对时间复杂度进行分析。递归式为:

T(n)=7T\left(\dfrac{n}{2}\right)+\Theta(n^2)

运用主定理求解出时间复杂度为 \Theta(n^{\log_2 7}),也就是 O(n^{2.81})

步骤上就比朴素的分治算法要麻烦一些。

  1. 与朴素分治算法相同,分解出左上,左下,右上,右下四个子矩阵。
  2. 创建 10\dfrac{n}2\times\dfrac n2 的矩阵 S_i,每个 S_i 保存两个子矩阵的和或差。
  3. 用子矩阵和 S_i 相乘,递归地计算 7\dfrac{n}2\times\dfrac n2 的矩阵 P_i
  4. 通过 P_i 的不同组合进行加减,得到 C 的子矩阵。

具体地,步骤 2 中创建的 10 个矩阵 S_i 分别为:

\begin{array}{ll} S_1&=B_{1,2}-B_{2,2}\\ S_2&=A_{1,1}+A_{1,2}\\ S_3&=A_{2,1}+A_{2,2}\\ S_4&=B_{2,1}-B_{1,1}\\ S_5&=A_{1,1}+A_{2,2}\\ S_6&=B_{1,1}+B_{2,2}\\ S_8&=A_{1,2}-A_{2,2}\\ S_8&=B_{2,1}+B_{2,2}\\ S_9&=A_{1,1}-A_{2,1}\\ S_{10}&=B_{1,1}+B_{1,2} \end{array}

步骤 3 中需要递归计算的 7 个矩阵 P_i 分别为:

\begin{array}{ll} P_1&=A_{1,1}\cdot S_1\\ P_2&=S_2\cdot B_{2,2}\\ P_3&=S_3\cdot B_{1,1}\\ P_4&=A_{2,2}\cdot S_4\\ P_5&=S_5\cdot S_6\\ P_6&=S_7\cdot S_8\\ P_7&=S_9\cdot S_{10} \end{array}

到了步骤 4,计算 C 的子矩阵的方法为:

\begin{array}{ll} C_{1,1}&=P_5+P_4-P_2+P_6\\ C_{1,2}&=P_1+P_2\\ C_{2,1}&=P_3+P_4\\ C_{2,2}&=P_5+P_1-P_3-P_7 \end{array}

这些式子为什么正确?直接代入即可验证。由于验证过程过于冗长,这里只举一例 C_{1,2}=P_1+P_2

\begin{aligned} P_1+P_2&=A_{1,1}\cdot S_1+S_2\cdot B_{2,2}\\ &=A_{1,1}\cdot (B_{1,2}-B_{2,2})+(A_{1,1}+A_{1,2})\cdot B_{2,2}\\ &=A_{1,1}\cdot B_{1,2}-A_{1,1}\cdot B_{2,2}+A_{1,1}\cdot B_{2,2}+A_{1,2}\cdot B_{2,2}\\ &=A_{1,1}\cdot B_{1,2}+A_{1,2}\cdot B_{2,2}\\ &=C_{1,2} \end{aligned}

至于 Strassen 具体是如何想到构造出这些算式的,则留给我们作为无限的遐想。有兴趣可以浏览这里。

Exercises

节选自《算法导论》4.2 练习。

  1. 试只用三次乘法完成复数相乘 (a+b\mathrm{i})(c+d\mathrm{i})=(ac-bd)+(ad+bc)\mathrm{i}

\begin{cases}\alpha=ac\\ \beta=bd\\\gamma=(a+b)(c+d)\\\end{cases}

则:

(a+b\mathrm{i})(c+d\mathrm{i})=(\alpha-\beta)+(\gamma-\alpha-\beta)\mathrm{i}

计算 \alpha,\beta,\gamma 只用了三次乘法,代价就是增加了加法次数。事实上 Gauss 早已发现了三次乘法进行复数相乘的方法,说不定 Strassen 是受到了这个启发?

  1. 若矩阵规模不是 2 的幂,如何应用 Strassen 算法?

用值 0 补齐到 2 的幂即可。这是实现代码时需要注意的地方。

  1. 已知用 k 次乘法操作完成两个 3\times 3 的矩阵相乘,那么满足在 o(n^{\log 7}) 的时间内完成 n\times n 的矩阵相乘,k 的最大值是多少?

容易列出递归式并求解:

T(n)&=kT\left(\dfrac n3\right)+O(n^2)\\ T(n)&=O(\log_3 k)\\ \log_3k&<\log_2 7\\ k_{\max}&=21 \end{aligned}
  1. 编写 Strassen 算法的伪代码。

凑合着看吧,这里把 \rm Partion 部分具体写了出来。

其中 A[1\sim n/2][1\sim n/2] 表示由 \forall i\in[1,n/2],\forall j\in[1,n/2],A_{i,j} 组成的子矩阵,其余类似。

\boxed{ \begin{array}{ll} &\textbf{Function Strassen}(A, B)\\ 1& n \gets A.\mathrm{rows}\\ 2& \textbf{if}\ n = 1\\ 3& \qquad \textbf{return}\ a[1, 1]\times b[1, 1]\\ 4& \mathrm{let}\ C\ \mathrm{be\ a\ new}\ n \times n\ \mathrm{matrix}\\ 5& A[1, 1] \gets A[1\sim n / 2][1\sim n / 2]\\ 6& A[1, 2] \gets A[1\sim n / 2][n / 2 + 1\sim n]\\ 7& A[2, 1] \gets A[n / 2 + 1\sim n][1\sim n / 2]\\ 8& A[2, 2] \gets A[n / 2 + 1\sim n][n / 2 + 1\sim n]\\ 9& B[1, 1] \gets B[1\sim n / 2][1\sim n / 2]\\ 10& B[1, 2] \gets B[1\sim n / 2][n / 2 + 1\sim n]\\ 11& B[2, 1] \gets B[n / 2 + 1\sim n][1\sim n / 2]\\ 12& B[2, 2] \gets B[n / 2 + 1\sim n][n / 2 + 1\sim n]\\ 13& S[1] \gets B[1, 2] - B[2, 2]\\ 14& S[2] \gets A[1, 1] + A[1, 2]\\ 15& S[3] \gets A[2, 1] + A[2, 2]\\ 16& S[4] \gets B[2, 1] - B[1, 1]\\ 17& S[5] \gets A[1, 1] + A[2, 2]\\ 18& S[6] \gets B[1, 1] + B[2, 2]\\ 19& S[7] \gets A[1, 2] - A[2, 2]\\ 20& S[8] \gets B[2, 1] + B[2, 2]\\ 21& S[9] \gets A[1, 1] - A[2, 1]\\ 22& S[10] \gets B[1, 1] + B[1, 2]\\ 23& P[1] \gets \mathrm{Strassen}(A[1, 1], S[1])\\ 24& P[2] \gets \mathrm{Strassen}(S[2], B[2, 2])\\ 25& P[3] \gets \mathrm{Strassen}(S[3], B[1, 1])\\ 26& P[4] \gets \mathrm{Strassen}(A[2, 2], S[4])\\ 27& P[5] \gets \mathrm{Strassen}(S[5], S[6])\\ 28& P[6] \gets \mathrm{Strassen}(S[7], S[8])\\ 29& P[7] \gets \mathrm{Strassen}(S[9], S[10])\\ 30& C[1\sim n / 2][1\sim n / 2] \gets P[5] + P[4] - P[2] + P[6]\\ 31& C[1\sim n / 2][n / 2 + 1\sim n] \gets P[1] + P[2]\\ 32& C[n / 2 + 1\sim n][1\sim n / 2] \gets P[3] + P[4]\\ 33& C[n / 2 + 1\sim n][n / 2 + 1\sim n] \gets P[5] + P[1] - P[3] - P[7]\\ 34& \textbf{return}\ C\\ \end{array}}

至于 C++ 代码,就留作读者课后习题吧。

Notice

然而很遗憾的是,Strassen 算法由于使用了大量递归,多次创建临时矩阵,常数因子比朴素矩阵乘法大非常多,因此 OI 范围内几乎没有应用价值。朴素的矩阵乘法就足够了。

如果确实想要优化常数,可以从以下方面考虑:

据 w33z8kqrqk8zzzx33 在帖子里说道,使用指令集可以在 n=2^{10} 时用少于 0.2 s 的时间完成(朴素矩阵乘法用指令集应该也会快不少吧?)。

而实际应用中的大型矩阵乘法都依赖于硬件(cache, GPU 等)和分布式计算。

尽管 Strassen 算法看上去没有什么实际用处,《算法导论》依然指出了 Strassen 算法的一个最重要的意义:在理论研究上作出了突破性的贡献。

类比 1959 年发明的 Shell Sort(希尔排序),目前来看也没有任何应用价值,但这是计算机第一次在整数排序上突破了 \Theta(n^2) 的壁障(当时常见的还是插入排序和冒泡排序)。

同样地,是 Strassen 算法使得矩阵乘法在渐进上界上第一次快于 \Theta(n^3),并鼓舞着后人继续在这方面探索。1969 年 Strassen 发表论文的的标题为《高斯消元法并非最优》,也正揭示了该算法的意义所在。

Extension

Strassen 算法其实能推广到很多矩阵操作,在原论文也略有提及。其基本思路都是分治并减少递归次数。本文将对其粗略说明。至于 Strassen 构造的这些玄妙算式是如何想到的,则给后人留下了神秘的美感。

下面所有算法的时间复杂度分析都与前文矩阵乘法相类似,不再赘述。

矩阵求逆

一般的矩阵求逆是用 \Theta(n^3) 的高斯消元法完成的。而 Strassen 在矩阵乘法的基础上,得到了更快的算法。这也是 Strassen 论文标题的由来。

基本思路依然与矩阵乘法类似。

先将矩阵 A 按照左上,左下,右上,右下分解出子矩阵。

创建如下 7 个矩阵 P_i

\begin{array}{ll} P_1&=A_{1,1}^{-1}\\ P_2&=A_{2,1}\cdot P_1\\ P_3&=P_{1}\cdot A_{1,2}\\ P_4&=A_{2,1}\cdot P_3\\ P_5&=P_4-A_{2,2}\\ P_6&=P_5^{-1}\\ P_7&=P_{3}\cdot P_6\cdot P_2\\ \end{array}

然后计算出 A 的逆矩阵 A^{-1}4 个子矩阵:

\begin{array}{ll} A^{-1}_{1,1}&=P_1-P_7\\ A^{-1}_{1,2}&=P_3\cdot P_6\\ A^{-1}_{2,1}&=P_6\cdot P_2\\ A^{-1}_{2,2}&=-P_6 \end{array}

当子矩阵缩小到一定规模,我们就可以直接用高斯消元法求解来减小常数。

当然,递归求逆过程中也要顺便用 Strassen 算法求矩阵乘法。

需要说明的是,Strassen 假定了操作过程中所有矩阵都是可逆的,而对于更复杂的情况则束手无策。对于 Strassen 算法在矩阵求逆操作上的更深研究,限于篇幅请参考这里。

解线性方程组

Similar results hold for solving a system of linear equations or computing a determinant.

Strassen 对于矩阵求逆和行列式计算的记录相当简略啊。

我们知道线性方程组可以写成矩阵的形式。

\begin{bmatrix} a_{1,1}&a_{1,2}&a_{1,3}&\cdots& a_{1,m}\\ a_{2,1}&a_{2,2}&a_{2,3}&\cdots& a_{2,m}\\ a_{3,1}&a_{3,2}&a_{3,3}&\cdots& a_{3,m}\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ a_{n,1}&a_{n,2}&a_{n,3}&\cdots& a_{n,m}\\ \end{bmatrix} \cdot \begin{bmatrix} x_1\\ x_2\\ x_3\\ \vdots\\ x_n \end{bmatrix}= \begin{bmatrix} y_1\\ y_2\\ y_3\\ \vdots\\ y_n \end{bmatrix}

上式简记为 A\cdot x=B,则有 x=A^{-1}\cdot B

如果逆矩阵存在,则方程恰好有一解,直接用 Strassen 算法进行矩阵求逆即可。

行列式

同样把矩阵 A 分解成四个子矩阵,然后用一种简洁的计算方式求解即可:

\operatorname{det}(A)=\operatorname{det}(A_{1,1})\cdot \operatorname{det}\left(A_{2,2}-A_{2,1}\cdot A_{1,1}^{-1}\cdot A_{1,2}\right)

Reference

关于矩阵乘法的其它研究,感兴趣的可以阅读以下文献:

  1. O(n^{2.37286}) 时间完成矩阵乘法 - Link
  2. 对于 n\times n 的矩阵左乘上 n\times n^p(p\le 0.294) 的矩阵,可以做到 O(n^2) 的时间复杂度 - Link
  3. 关于 GPU 上进行矩阵乘法的效率 - Link