
- 資料結構與演算法
- DSA - 首頁
- DSA - 概述
- DSA - 環境設定
- DSA - 演算法基礎
- DSA - 漸進分析
- 資料結構
- DSA - 資料結構基礎
- DSA - 資料結構和型別
- DSA - 陣列資料結構
- 連結串列
- DSA - 連結串列資料結構
- DSA - 雙向連結串列資料結構
- DSA - 迴圈連結串列資料結構
- 棧與佇列
- DSA - 棧資料結構
- DSA - 表示式解析
- DSA - 佇列資料結構
- 搜尋演算法
- DSA - 搜尋演算法
- DSA - 線性搜尋演算法
- DSA - 二分搜尋演算法
- DSA - 插值搜尋
- DSA - 跳躍搜尋演算法
- DSA - 指數搜尋
- DSA - 斐波那契搜尋
- DSA - 子列表搜尋
- DSA - 雜湊表
- 排序演算法
- DSA - 排序演算法
- DSA - 氣泡排序演算法
- DSA - 插入排序演算法
- DSA - 選擇排序演算法
- DSA - 歸併排序演算法
- DSA - 希爾排序演算法
- DSA - 堆排序
- DSA - 桶排序演算法
- DSA - 計數排序演算法
- DSA - 基數排序演算法
- DSA - 快速排序演算法
- 圖資料結構
- DSA - 圖資料結構
- DSA - 深度優先遍歷
- DSA - 廣度優先遍歷
- DSA - 生成樹
- 樹資料結構
- DSA - 樹資料結構
- DSA - 樹的遍歷
- DSA - 二叉搜尋樹
- DSA - AVL樹
- DSA - 紅黑樹
- DSA - B樹
- DSA - B+樹
- DSA - 伸展樹
- DSA - 字典樹
- DSA - 堆資料結構
- 遞迴
- DSA - 遞迴演算法
- DSA - 使用遞迴實現漢諾塔問題
- DSA - 使用遞迴實現斐波那契數列
- 分治法
- DSA - 分治法
- DSA - 最大最小值問題
- DSA - Strassen矩陣乘法
- DSA - Karatsuba演算法
- 貪心演算法
- DSA - 貪心演算法
- DSA - 旅行商問題(貪心演算法)
- DSA - Prim最小生成樹演算法
- DSA - Kruskal最小生成樹演算法
- DSA - Dijkstra最短路徑演算法
- DSA - 地圖著色演算法
- DSA - 分數揹包問題
- DSA - 帶截止日期的作業排序
- DSA - 最優合併模式演算法
- 動態規劃
- DSA - 動態規劃
- DSA - 矩陣鏈乘法
- DSA - Floyd-Warshall演算法
- DSA - 0-1揹包問題
- DSA - 最長公共子序列演算法
- DSA - 旅行商問題(動態規劃)
- 近似演算法
- DSA - 近似演算法
- DSA - 頂點覆蓋演算法
- DSA - 集合覆蓋問題
- DSA - 旅行商問題(近似演算法)
- 隨機化演算法
- DSA - 隨機化演算法
- DSA - 隨機化快速排序演算法
- DSA - Karger最小割演算法
- DSA - Fisher-Yates洗牌演算法
- DSA有用資源
- DSA - 問答
- DSA - 快速指南
- DSA - 有用資源
- DSA - 討論
Strassen矩陣乘法
Strassen矩陣乘法是解決矩陣乘法問題的分治法。通常的矩陣乘法方法將每一行與每一列相乘以獲得乘積矩陣。這種方法的時間複雜度為O(n3),因為它需要兩個迴圈進行乘法運算。Strassen方法是為了將時間複雜度從O(n3)降低到O(nlog 7)而引入的。
樸素方法
首先,我們將討論樸素方法及其複雜度。在這裡,我們計算Z=𝑿X × Y。使用樸素方法,如果兩個矩陣(X和Y)的階數為p × q和q × r,則可以將這兩個矩陣相乘,並且結果矩陣的階數將為p × r。以下虛擬碼描述了樸素乘法:
Algorithm: Matrix-Multiplication (X, Y, Z) for i = 1 to p do for j = 1 to r do Z[i,j] := 0 for k = 1 to q do Z[i,j] := Z[i,j] + X[i,k] × Y[k,j]
複雜度
在這裡,我們假設整數運算需要O(1)時間。該演算法中有三個for迴圈,其中一個巢狀在另一箇中。因此,該演算法需要O(n3)時間來執行。
Strassen矩陣乘法演算法
在這種情況下,使用Strassen矩陣乘法演算法,可以稍微改進時間消耗。
Strassen矩陣乘法只能對n為2的冪的方陣進行。兩個矩陣的階數均為n × n。
將X、Y和Z劃分為四個(n/2)×(n/2)矩陣,如下所示:
$Z = \begin{bmatrix}I & J \\K & L \end{bmatrix}$ $X = \begin{bmatrix}A & B \\C & D \end{bmatrix}$ 和 $Y = \begin{bmatrix}E & F \\G & H \end{bmatrix}$
使用Strassen演算法計算以下內容:
$$M_{1} \: \colon= (A+C) \times (E+F)$$
$$M_{2} \: \colon= (B+D) \times (G+H)$$
$$M_{3} \: \colon= (A-D) \times (E+H)$$
$$M_{4} \: \colon= A \times (F-H)$$
$$M_{5} \: \colon= (C+D) \times (E)$$
$$M_{6} \: \colon= (A+B) \times (H)$$
$$M_{7} \: \colon= D \times (G-E)$$
然後,
$$I \: \colon= M_{2} + M_{3} - M_{6} - M_{7}$$
$$J \: \colon= M_{4} + M_{6}$$
$$K \: \colon= M_{5} + M_{7}$$
$$L \: \colon= M_{1} - M_{3} - M_{4} - M_{5}$$
分析
$$T(n)=\begin{cases}c & if\:n= 1\\7\:x\:T(\frac{n}{2})+d\:x\:n^2 & otherwise\end{cases} \:其中\: c\: 和 \:d\:是常數$$
使用此遞迴關係,我們得到 $T(n) = O(n^{log7})$
因此,Strassen矩陣乘法演算法的複雜度為 $O(n^{log7})$。
示例
讓我們看看Strassen矩陣乘法在各種程式語言中的實現:C、C++、Java、Python。
#include<stdio.h> int main(){ int z[2][2]; int i, j; int m1, m2, m3, m4 , m5, m6, m7; int x[2][2] = { {12, 34}, {22, 10} }; int y[2][2] = { {3, 4}, {2, 1} }; printf("The first matrix is: "); for(i = 0; i < 2; i++) { printf("\n"); for(j = 0; j < 2; j++) printf("%d\t", x[i][j]); } printf("\nThe second matrix is: "); for(i = 0; i < 2; i++) { printf("\n"); for(j = 0; j < 2; j++) printf("%d\t", y[i][j]); } m1= (x[0][0] + x[1][1]) * (y[0][0] + y[1][1]); m2= (x[1][0] + x[1][1]) * y[0][0]; m3= x[0][0] * (y[0][1] - y[1][1]); m4= x[1][1] * (y[1][0] - y[0][0]); m5= (x[0][0] + x[0][1]) * y[1][1]; m6= (x[1][0] - x[0][0]) * (y[0][0]+y[0][1]); m7= (x[0][1] - x[1][1]) * (y[1][0]+y[1][1]); z[0][0] = m1 + m4- m5 + m7; z[0][1] = m3 + m5; z[1][0] = m2 + m4; z[1][1] = m1 - m2 + m3 + m6; printf("\nProduct achieved using Strassen's algorithm: "); for(i = 0; i < 2 ; i++) { printf("\n"); for(j = 0; j < 2; j++) printf("%d\t", z[i][j]); } return 0; }
輸出
The first matrix is: 12 34 22 10 The second matrix is: 3 4 2 1 Product achieved using Strassen's algorithm: 104 82 86 98
#include<iostream> using namespace std; int main() { int z[2][2]; int i, j; int m1, m2, m3, m4 , m5, m6, m7; int x[2][2] = { {12, 34}, {22, 10} }; int y[2][2] = { {3, 4}, {2, 1} }; cout<<"The first matrix is: "; for(i = 0; i < 2; i++) { cout<<endl; for(j = 0; j < 2; j++) cout<<x[i][j]<<" "; } cout<<"\nThe second matrix is: "; for(i = 0;i < 2; i++){ cout<<endl; for(j = 0;j < 2; j++) cout<<y[i][j]<<" "; } m1 = (x[0][0] + x[1][1]) * (y[0][0] + y[1][1]); m2 = (x[1][0] + x[1][1]) * y[0][0]; m3 = x[0][0] * (y[0][1] - y[1][1]); m4 = x[1][1] * (y[1][0] - y[0][0]); m5 = (x[0][0] + x[0][1]) * y[1][1]; m6 = (x[1][0] - x[0][0]) * (y[0][0]+y[0][1]); m7 = (x[0][1] - x[1][1]) * (y[1][0]+y[1][1]); z[0][0] = m1 + m4- m5 + m7; z[0][1] = m3 + m5; z[1][0] = m2 + m4; z[1][1] = m1 - m2 + m3 + m6; cout<<"\nProduct achieved using Strassen's algorithm: "; for(i = 0; i < 2 ; i++) { cout<<endl; for(j = 0; j < 2; j++) cout<<z[i][j]<<" "; } return 0; }
輸出
The first matrix is: 12 34 22 10 The second matrix is: 3 4 2 1 Product achieved using Strassen's algorithm: 104 82 86 98
public class Strassens { public static void main(String[] args) { int[][] x = {{12, 34}, {22, 10}}; int[][] y = {{3, 4}, {2, 1}}; int z[][] = new int[2][2]; int m1, m2, m3, m4 , m5, m6, m7; System.out.print("The first matrix is: "); for(int i = 0; i<2; i++) { System.out.println();//new line for(int j = 0; j<2; j++) { System.out.print(x[i][j] + "\t"); } } System.out.print("\nThe second matrix is: "); for(int i = 0; i<2; i++) { System.out.println();//new line for(int j = 0; j<2; j++) { System.out.print(y[i][j] + "\t"); } } m1 = (x[0][0] + x[1][1]) * (y[0][0] + y[1][1]); m2 = (x[1][0] + x[1][1]) * y[0][0]; m3 = x[0][0] * (y[0][1] - y[1][1]); m4 = x[1][1] * (y[1][0] - y[0][0]); m5 = (x[0][0] + x[0][1]) * y[1][1]; m6 = (x[1][0] - x[0][0]) * (y[0][0]+y[0][1]); m7 = (x[0][1] - x[1][1]) * (y[1][0]+y[1][1]); z[0][0] = m1 + m4- m5 + m7; z[0][1] = m3 + m5; z[1][0] = m2 + m4; z[1][1] = m1 - m2 + m3 + m6; System.out.print("\nProduct achieved using Strassen's algorithm: "); for(int i = 0; i<2; i++) { System.out.println();//new line for(int j = 0; j<2; j++) { System.out.print(z[i][j] + "\t"); } } } }
輸出
The first matrix is: 12 34 22 10 The second matrix is: 3 4 2 1 Product achieved using Strassen's algorithm: 104 82 86 98
import numpy as np x = np.array([[12, 34], [22, 10]]) y = np.array([[3, 4], [2, 1]]) z = np.zeros((2, 2)) m1, m2, m3, m4, m5, m6, m7 = 0, 0, 0, 0, 0, 0, 0 print("The first matrix is: ") for i in range(2): print() for j in range(2): print(x[i][j], end="\t") print("\nThe second matrix is: ") for i in range(2): print() for j in range(2): print(y[i][j], end="\t") m1 = (x[0][0] + x[1][1]) * (y[0][0] + y[1][1]) m2 = (x[1][0] + x[1][1]) * y[0][0] m3 = x[0][0] * (y[0][1] - y[1][1]) m4 = x[1][1] * (y[1][0] - y[0][0]) m5 = (x[0][0] + x[0][1]) * y[1][1] m6 = (x[1][0] - x[0][0]) * (y[0][0] + y[0][1]) m7 = (x[0][1] - x[1][1]) * (y[1][0] + y[1][1]) z[0][0] = m1 + m4 - m5 + m7 z[0][1] = m3 + m5 z[1][0] = m2 + m4 z[1][1] = m1 - m2 + m3 + m6 print("\nProduct achieved using Strassen's algorithm: ") for i in range(2): print() for j in range(2): print(z[i][j], end="\t")
輸出
The first matrix is: 12 34 22 10 The second matrix is: 3 4 2 1 Product achieved using Strassen's algorithm: 104.0 82.0 86.0 98.0