Update
- [2021/11/14] 感谢 @serverkiller 指出错误并修正。
- [2021/11/14] 替换了扩展阅读第二个 link。
Introduction
Strassen 算法在 OI 上没有任何应用价值,不过了解一下理论计算机科学相关还蛮有意思的。
计算两个矩阵相乘的朴素方法是 \Theta(n^3)。而 Strassen 算法通过采用分治策略,并减少递归次数,实现在 \Theta(n^{\log 7}) 时间内完成。\log 7\approx 2.808,也就是 O(n^{2.81})。
前置知识: 矩阵基础,主定理。
其实主定理只需要知道结论即可,这里放一下简化结论:
T(n)=aT\left(\dfrac{n}{b}\right)+O(n^d),a\ge 1,b>1\\
T(n)=\begin{cases}
O(n^d)&,d>\log_{b}a\\
O(n^d\log n)&,d=\log_b a\\
O(n^{\log_b a})&,d<\log_b a
\end{cases}
Analysis
从分治开始
在介绍 Strassen 算法之前,先探讨如何用分治完成矩阵乘法。
分治策略其实非常直截了当,就是将 n\times n 的矩阵划分为四个 \dfrac{n}{2}\times \dfrac{n}{2} 的矩阵。当然,这种做法要求 n 是 2 的幂,但相关细节我们稍后探讨。
先以 2\times 2 的矩阵为例。
\begin{bmatrix}A_{1,1}&A_{1,2}\\A_{2,1}&A_{2,2}\end{bmatrix}\cdot\begin{bmatrix}B_{1,1}&B_{1,2}\\B_{2,1}&B_{2,2}\end{bmatrix}=\begin{bmatrix}C_{1,1}&C_{1,2}\\C_{2,1}&C_{2,2}\end{bmatrix}
具体写出计算矩阵 C 的等式:
C_{1,1}=A_{1,1}\cdot B_{1,1}+A_{1,2}\cdot B_{2,1}\\
C_{1,2}=A_{1,1}\cdot B_{1,2}+A_{1,2}\cdot B_{2,2}\\
C_{2,1}=A_{2,1}\cdot B_{1,1}+A_{2,2}\cdot B_{2,1}\\
C_{2,2}=A_{2,1}\cdot B_{1,2}+A_{2,2}\cdot B_{2,2}\\
对于每一条公式,相当于计算两对 \dfrac{n}{2}\times\dfrac{n}{2} 矩阵乘法,再计算一次这样的矩阵加法。
直接利用这个运算方式就可以写出递归分治策略。形象化地说,就是将矩阵分解为左上,左下,右上,右下四个子矩阵再分别进行运算。
\boxed{
\begin{array}{ll}
&\textbf {Function Multiply}(A,B)\\
1& n\gets A.\text{rows}\\
2& \text{let}\ C\ \text{be a new } n\times n\ \text{matrix}\\
3& \textbf{if } n=1\\
4& \qquad c_{1,1}\gets a_{1,1}\times b_{1,1}\\
5&\textbf{else}\ \mathrm{Partion}\ A,B,C\\
6&\qquad C_{1,1}\gets\text {Multiply}(A_{1,1},B_{1,1})+\text {Multiply}(A_{1,2},B_{2,1})\\
7&\qquad C_{1,2}\gets\text {Multiply}(A_{1,1},B_{1,2})+\text {Multiply}(A_{1,2},B_{2,2})\\
8&\qquad C_{2,1}\gets\text {Multiply}(A_{2,1},B_{1,1})+\text {Multiply}(A_{2,2},B_{2,1})\\
9&\qquad C_{2,2}\gets\text {Multiply}(A_{2,1},B_{1,2})+\text {Multiply}(A_{2,2},B_{2,2})\\
10&\textbf{return}\ C
\end{array}}
注意 \rm Partion 部分只需要对应好分解后的矩阵的下标即可。具体的对应方法可以参考 Strassen 算法伪代码。
需要说明的是,《算法导论》认为,只需通过下标计算即可对实现分解子矩阵并操作,而不用 \Theta(n^2) 拷贝子矩阵。然而在亲自动手实现代码时,避免拷贝子矩阵来进行其他操作是异常麻烦的,书中也没有给出伪代码。况且拷贝也不影响总的时间复杂度,因为矩阵加法需要不可避免的 \Theta(n^2)。但拷贝操作对常数因子影响比较大。如果有谁会实现不用拷贝的请务必把代码发给我/kk
现在来看一下这段代码的时间复杂度。
我们分解出了 8 个子问题,每个子问题规模缩小了一半,同时花了 \Theta(n^2) 时间分解出子矩阵,花了 \Theta(n^2) 进行矩阵加法。容易写出其递归式:
\begin{aligned}
T(n)&=8T\left(\dfrac{n}{2}\right)+\Theta(n^2)
\end{aligned}
运用主定理即可求解时间复杂度为 \Theta(n^3)。
没有优化啊?这就来到 Strassen 算法的另一个核心:减少递归次数。
还能再少一次
Strassen 算法在朴素分治算法的基础上,只进行了 7 次递归。当然减少递归次数的代价就是多进行了几次矩阵加法,但幸好只是常数级别。
我们先对时间复杂度进行分析。递归式为:
T(n)=7T\left(\dfrac{n}{2}\right)+\Theta(n^2)
运用主定理求解出时间复杂度为 \Theta(n^{\log_2 7}),也就是 O(n^{2.81})。
步骤上就比朴素的分治算法要麻烦一些。
- 与朴素分治算法相同,分解出左上,左下,右上,右下四个子矩阵。
- 创建 10 个 \dfrac{n}2\times\dfrac n2 的矩阵 S_i,每个 S_i 保存两个子矩阵的和或差。
- 用子矩阵和 S_i 相乘,递归地计算 7 个 \dfrac{n}2\times\dfrac n2 的矩阵 P_i。
- 通过 P_i 的不同组合进行加减,得到 C 的子矩阵。
具体地,步骤 2 中创建的 10 个矩阵 S_i 分别为:
\begin{array}{ll}
S_1&=B_{1,2}-B_{2,2}\\
S_2&=A_{1,1}+A_{1,2}\\
S_3&=A_{2,1}+A_{2,2}\\
S_4&=B_{2,1}-B_{1,1}\\
S_5&=A_{1,1}+A_{2,2}\\
S_6&=B_{1,1}+B_{2,2}\\
S_8&=A_{1,2}-A_{2,2}\\
S_8&=B_{2,1}+B_{2,2}\\
S_9&=A_{1,1}-A_{2,1}\\
S_{10}&=B_{1,1}+B_{1,2}
\end{array}
步骤 3 中需要递归计算的 7 个矩阵 P_i 分别为:
\begin{array}{ll}
P_1&=A_{1,1}\cdot S_1\\
P_2&=S_2\cdot B_{2,2}\\
P_3&=S_3\cdot B_{1,1}\\
P_4&=A_{2,2}\cdot S_4\\
P_5&=S_5\cdot S_6\\
P_6&=S_7\cdot S_8\\
P_7&=S_9\cdot S_{10}
\end{array}
到了步骤 4,计算 C 的子矩阵的方法为:
\begin{array}{ll}
C_{1,1}&=P_5+P_4-P_2+P_6\\
C_{1,2}&=P_1+P_2\\
C_{2,1}&=P_3+P_4\\
C_{2,2}&=P_5+P_1-P_3-P_7
\end{array}
这些式子为什么正确?直接代入即可验证。由于验证过程过于冗长,这里只举一例 C_{1,2}=P_1+P_2。
\begin{aligned}
P_1+P_2&=A_{1,1}\cdot S_1+S_2\cdot B_{2,2}\\
&=A_{1,1}\cdot (B_{1,2}-B_{2,2})+(A_{1,1}+A_{1,2})\cdot B_{2,2}\\
&=A_{1,1}\cdot B_{1,2}-A_{1,1}\cdot B_{2,2}+A_{1,1}\cdot B_{2,2}+A_{1,2}\cdot B_{2,2}\\
&=A_{1,1}\cdot B_{1,2}+A_{1,2}\cdot B_{2,2}\\
&=C_{1,2}
\end{aligned}
至于 Strassen 具体是如何想到构造出这些算式的,则留给我们作为无限的遐想。有兴趣可以浏览这里。
Exercises
节选自《算法导论》4.2 练习。
- 试只用三次乘法完成复数相乘 (a+b\mathrm{i})(c+d\mathrm{i})=(ac-bd)+(ad+bc)\mathrm{i}。
设
\begin{cases}\alpha=ac\\ \beta=bd\\\gamma=(a+b)(c+d)\\\end{cases}
则:
(a+b\mathrm{i})(c+d\mathrm{i})=(\alpha-\beta)+(\gamma-\alpha-\beta)\mathrm{i}
计算 \alpha,\beta,\gamma 只用了三次乘法,代价就是增加了加法次数。事实上 Gauss 早已发现了三次乘法进行复数相乘的方法,说不定 Strassen 是受到了这个启发?
- 若矩阵规模不是 2 的幂,如何应用 Strassen 算法?
用值 0 补齐到 2 的幂即可。这是实现代码时需要注意的地方。
- 已知用 k 次乘法操作完成两个 3\times 3 的矩阵相乘,那么满足在 o(n^{\log 7}) 的时间内完成 n\times n 的矩阵相乘,k 的最大值是多少?
容易列出递归式并求解:
T(n)&=kT\left(\dfrac n3\right)+O(n^2)\\
T(n)&=O(\log_3 k)\\
\log_3k&<\log_2 7\\
k_{\max}&=21
\end{aligned}
- 编写 Strassen 算法的伪代码。
凑合着看吧,这里把 \rm Partion 部分具体写了出来。
其中 A[1\sim n/2][1\sim n/2] 表示由 \forall i\in[1,n/2],\forall j\in[1,n/2],A_{i,j} 组成的子矩阵,其余类似。
\boxed{
\begin{array}{ll}
&\textbf{Function Strassen}(A, B)\\
1& n \gets A.\mathrm{rows}\\
2& \textbf{if}\ n = 1\\
3& \qquad \textbf{return}\ a[1, 1]\times b[1, 1]\\
4& \mathrm{let}\ C\ \mathrm{be\ a\ new}\ n \times n\ \mathrm{matrix}\\
5& A[1, 1] \gets A[1\sim n / 2][1\sim n / 2]\\
6& A[1, 2] \gets A[1\sim n / 2][n / 2 + 1\sim n]\\
7& A[2, 1] \gets A[n / 2 + 1\sim n][1\sim n / 2]\\
8& A[2, 2] \gets A[n / 2 + 1\sim n][n / 2 + 1\sim n]\\
9& B[1, 1] \gets B[1\sim n / 2][1\sim n / 2]\\
10& B[1, 2] \gets B[1\sim n / 2][n / 2 + 1\sim n]\\
11& B[2, 1] \gets B[n / 2 + 1\sim n][1\sim n / 2]\\
12& B[2, 2] \gets B[n / 2 + 1\sim n][n / 2 + 1\sim n]\\
13& S[1] \gets B[1, 2] - B[2, 2]\\
14& S[2] \gets A[1, 1] + A[1, 2]\\
15& S[3] \gets A[2, 1] + A[2, 2]\\
16& S[4] \gets B[2, 1] - B[1, 1]\\
17& S[5] \gets A[1, 1] + A[2, 2]\\
18& S[6] \gets B[1, 1] + B[2, 2]\\
19& S[7] \gets A[1, 2] - A[2, 2]\\
20& S[8] \gets B[2, 1] + B[2, 2]\\
21& S[9] \gets A[1, 1] - A[2, 1]\\
22& S[10] \gets B[1, 1] + B[1, 2]\\
23& P[1] \gets \mathrm{Strassen}(A[1, 1], S[1])\\
24& P[2] \gets \mathrm{Strassen}(S[2], B[2, 2])\\
25& P[3] \gets \mathrm{Strassen}(S[3], B[1, 1])\\
26& P[4] \gets \mathrm{Strassen}(A[2, 2], S[4])\\
27& P[5] \gets \mathrm{Strassen}(S[5], S[6])\\
28& P[6] \gets \mathrm{Strassen}(S[7], S[8])\\
29& P[7] \gets \mathrm{Strassen}(S[9], S[10])\\
30& C[1\sim n / 2][1\sim n / 2] \gets P[5] + P[4] - P[2] + P[6]\\
31& C[1\sim n / 2][n / 2 + 1\sim n] \gets P[1] + P[2]\\
32& C[n / 2 + 1\sim n][1\sim n / 2] \gets P[3] + P[4]\\
33& C[n / 2 + 1\sim n][n / 2 + 1\sim n] \gets P[5] + P[1] - P[3] - P[7]\\
34& \textbf{return}\ C\\
\end{array}}
至于 C++ 代码,就留作读者课后习题吧。
Notice
然而很遗憾的是,Strassen 算法由于使用了大量递归,多次创建临时矩阵,常数因子比朴素矩阵乘法大非常多,因此 OI 范围内几乎没有应用价值。朴素的矩阵乘法就足够了。
如果确实想要优化常数,可以从以下方面考虑:
- 使用指令集等卡常科技。
- 在子矩阵小到一定规模时使用朴素矩阵乘法。
- 尽可能地减少创建临时矩阵和复制次数。
据 w33z8kqrqk8zzzx33 在帖子里说道,使用指令集可以在 n=2^{10} 时用少于 0.2 s 的时间完成(朴素矩阵乘法用指令集应该也会快不少吧?)。
而实际应用中的大型矩阵乘法都依赖于硬件(cache, GPU 等)和分布式计算。
尽管 Strassen 算法看上去没有什么实际用处,《算法导论》依然指出了 Strassen 算法的一个最重要的意义:在理论研究上作出了突破性的贡献。
类比 1959 年发明的 Shell Sort(希尔排序),目前来看也没有任何应用价值,但这是计算机第一次在整数排序上突破了 \Theta(n^2) 的壁障(当时常见的还是插入排序和冒泡排序)。
同样地,是 Strassen 算法使得矩阵乘法在渐进上界上第一次快于 \Theta(n^3),并鼓舞着后人继续在这方面探索。1969 年 Strassen 发表论文的的标题为《高斯消元法并非最优》,也正揭示了该算法的意义所在。
Extension
Strassen 算法其实能推广到很多矩阵操作,在原论文也略有提及。其基本思路都是分治并减少递归次数。本文将对其粗略说明。至于 Strassen 构造的这些玄妙算式是如何想到的,则给后人留下了神秘的美感。
下面所有算法的时间复杂度分析都与前文矩阵乘法相类似,不再赘述。
矩阵求逆
一般的矩阵求逆是用 \Theta(n^3) 的高斯消元法完成的。而 Strassen 在矩阵乘法的基础上,得到了更快的算法。这也是 Strassen 论文标题的由来。
基本思路依然与矩阵乘法类似。
先将矩阵 A 按照左上,左下,右上,右下分解出子矩阵。
创建如下 7 个矩阵 P_i:
\begin{array}{ll}
P_1&=A_{1,1}^{-1}\\
P_2&=A_{2,1}\cdot P_1\\
P_3&=P_{1}\cdot A_{1,2}\\
P_4&=A_{2,1}\cdot P_3\\
P_5&=P_4-A_{2,2}\\
P_6&=P_5^{-1}\\
P_7&=P_{3}\cdot P_6\cdot P_2\\
\end{array}
然后计算出 A 的逆矩阵 A^{-1} 的 4 个子矩阵:
\begin{array}{ll}
A^{-1}_{1,1}&=P_1-P_7\\
A^{-1}_{1,2}&=P_3\cdot P_6\\
A^{-1}_{2,1}&=P_6\cdot P_2\\
A^{-1}_{2,2}&=-P_6
\end{array}
当子矩阵缩小到一定规模,我们就可以直接用高斯消元法求解来减小常数。
当然,递归求逆过程中也要顺便用 Strassen 算法求矩阵乘法。
需要说明的是,Strassen 假定了操作过程中所有矩阵都是可逆的,而对于更复杂的情况则束手无策。对于 Strassen 算法在矩阵求逆操作上的更深研究,限于篇幅请参考这里。
解线性方程组
Similar results hold for solving a system of linear equations or computing a
determinant.
Strassen 对于矩阵求逆和行列式计算的记录相当简略啊。
我们知道线性方程组可以写成矩阵的形式。
\begin{bmatrix}
a_{1,1}&a_{1,2}&a_{1,3}&\cdots& a_{1,m}\\
a_{2,1}&a_{2,2}&a_{2,3}&\cdots& a_{2,m}\\
a_{3,1}&a_{3,2}&a_{3,3}&\cdots& a_{3,m}\\
\vdots&\vdots&\vdots&\ddots&\vdots\\
a_{n,1}&a_{n,2}&a_{n,3}&\cdots& a_{n,m}\\
\end{bmatrix}
\cdot
\begin{bmatrix}
x_1\\
x_2\\
x_3\\
\vdots\\
x_n
\end{bmatrix}=
\begin{bmatrix}
y_1\\
y_2\\
y_3\\
\vdots\\
y_n
\end{bmatrix}
上式简记为 A\cdot x=B,则有 x=A^{-1}\cdot
B。
如果逆矩阵存在,则方程恰好有一解,直接用 Strassen 算法进行矩阵求逆即可。
行列式
同样把矩阵 A 分解成四个子矩阵,然后用一种简洁的计算方式求解即可:
\operatorname{det}(A)=\operatorname{det}(A_{1,1})\cdot \operatorname{det}\left(A_{2,2}-A_{2,1}\cdot A_{1,1}^{-1}\cdot A_{1,2}\right)
Reference
- 《算法导论》第三版第四章。
- Strassen 原论文 - Gaussian Elimination is not Optimal
关于矩阵乘法的其它研究,感兴趣的可以阅读以下文献:
- 用 O(n^{2.37286}) 时间完成矩阵乘法 - Link
- 对于 n\times n 的矩阵左乘上 n\times n^p(p\le 0.294) 的矩阵,可以做到 O(n^2) 的时间复杂度 - Link
- 关于 GPU 上进行矩阵乘法的效率 - Link