Strassen矩陣乘法



Strassen矩陣乘法是解決矩陣乘法問題的分治法。通常的矩陣乘法方法將每一行與每一列相乘以獲得乘積矩陣。這種方法的時間複雜度為O(n3),因為它需要兩個迴圈進行乘法運算。Strassen方法是為了將時間複雜度從O(n3)降低到O(nlog 7)而引入的。

樸素方法

首先,我們將討論樸素方法及其複雜度。在這裡,我們計算Z=𝑿X × Y。使用樸素方法,如果兩個矩陣(XY)的階數為p × qq × r,則可以將這兩個矩陣相乘,並且結果矩陣的階數將為p × r。以下虛擬碼描述了樸素乘法:

Algorithm: Matrix-Multiplication (X, Y, Z) 
for i = 1 to p do 
   for j = 1 to r do 
      Z[i,j] := 0 
      for k = 1 to q do 
         Z[i,j] := Z[i,j] + X[i,k] × Y[k,j] 

複雜度

在這裡,我們假設整數運算需要O(1)時間。該演算法中有三個for迴圈,其中一個巢狀在另一箇中。因此,該演算法需要O(n3)時間來執行。

Strassen矩陣乘法演算法

在這種情況下,使用Strassen矩陣乘法演算法,可以稍微改進時間消耗。

Strassen矩陣乘法只能對n2的冪方陣進行。兩個矩陣的階數均為n × n

XYZ劃分為四個(n/2)×(n/2)矩陣,如下所示:

$Z = \begin{bmatrix}I & J \\K & L \end{bmatrix}$ $X = \begin{bmatrix}A & B \\C & D \end{bmatrix}$$Y = \begin{bmatrix}E & F \\G & H \end{bmatrix}$

使用Strassen演算法計算以下內容:

$$M_{1} \: \colon= (A+C) \times (E+F)$$

$$M_{2} \: \colon= (B+D) \times (G+H)$$

$$M_{3} \: \colon= (A-D) \times (E+H)$$

$$M_{4} \: \colon= A \times (F-H)$$

$$M_{5} \: \colon= (C+D) \times (E)$$

$$M_{6} \: \colon= (A+B) \times (H)$$

$$M_{7} \: \colon= D \times (G-E)$$

然後,

$$I \: \colon= M_{2} + M_{3} - M_{6} - M_{7}$$

$$J \: \colon= M_{4} + M_{6}$$

$$K \: \colon= M_{5} + M_{7}$$

$$L \: \colon= M_{1} - M_{3} - M_{4} - M_{5}$$

分析

$$T(n)=\begin{cases}c & if\:n= 1\\7\:x\:T(\frac{n}{2})+d\:x\:n^2 & otherwise\end{cases} \:其中\: c\: 和 \:d\:是常數$$

使用此遞迴關係,我們得到 $T(n) = O(n^{log7})$

因此,Strassen矩陣乘法演算法的複雜度為 $O(n^{log7})$。

示例

讓我們看看Strassen矩陣乘法在各種程式語言中的實現:C、C++、Java、Python。

#include<stdio.h>
int main(){
   int z[2][2];
   int i, j;
   int m1, m2, m3, m4 , m5, m6, m7;
   int x[2][2] = {
       {12, 34}, 
       {22, 10}
       };
   int y[2][2] = {
       {3, 4}, 
       {2, 1}
   };
   printf("The first matrix is: ");
   for(i = 0; i < 2; i++) {
      printf("\n");
      for(j = 0; j < 2; j++)
         printf("%d\t", x[i][j]);
   }
   printf("\nThe second matrix is: ");
   for(i = 0; i < 2; i++) {
      printf("\n");
      for(j = 0; j < 2; j++)
         printf("%d\t", y[i][j]);
   }
   m1= (x[0][0] + x[1][1]) * (y[0][0] + y[1][1]);
   m2= (x[1][0] + x[1][1]) * y[0][0];
   m3= x[0][0] * (y[0][1] - y[1][1]);
   m4= x[1][1] * (y[1][0] - y[0][0]);
   m5= (x[0][0] + x[0][1]) * y[1][1];
   m6= (x[1][0] - x[0][0]) * (y[0][0]+y[0][1]);
   m7= (x[0][1] - x[1][1]) * (y[1][0]+y[1][1]);
   z[0][0] = m1 + m4- m5 + m7;
   z[0][1] = m3 + m5;
   z[1][0] = m2 + m4;
   z[1][1] = m1 - m2 + m3 + m6;
   printf("\nProduct achieved using Strassen's algorithm: ");
   for(i = 0; i < 2 ; i++) {
      printf("\n");
      for(j = 0; j < 2; j++)
         printf("%d\t", z[i][j]);
   }
   return 0;
}

輸出

The first matrix is: 
12	34	
22	10	
The second matrix is: 
3	4	
2	1	
Product achieved using Strassen's algorithm: 
104	82	
86	98
#include<iostream>
using namespace std;
int main() {
   int z[2][2];
   int i, j;
   int m1, m2, m3, m4 , m5, m6, m7;
      int x[2][2] = {
         {12, 34}, 
         {22, 10}
      };
   int y[2][2] = {
      {3, 4}, 
      {2, 1}
   };
   cout<<"The first matrix is: ";
   for(i = 0; i < 2; i++) {
      cout<<endl;
      for(j = 0; j < 2; j++)
         cout<<x[i][j]<<" ";
   }
   cout<<"\nThe second matrix is: ";
   for(i = 0;i < 2; i++){
      cout<<endl;
      for(j = 0;j < 2; j++)
         cout<<y[i][j]<<" ";
   }

   m1 = (x[0][0] + x[1][1]) * (y[0][0] + y[1][1]);
   m2 = (x[1][0] + x[1][1]) * y[0][0];
   m3 = x[0][0] * (y[0][1] - y[1][1]);
   m4 = x[1][1] * (y[1][0] - y[0][0]);
   m5 = (x[0][0] + x[0][1]) * y[1][1];
   m6 = (x[1][0] - x[0][0]) * (y[0][0]+y[0][1]);
   m7 = (x[0][1] - x[1][1]) * (y[1][0]+y[1][1]);

   z[0][0] = m1 + m4- m5 + m7;
   z[0][1] = m3 + m5;
   z[1][0] = m2 + m4;
   z[1][1] = m1 - m2 + m3 + m6;

   cout<<"\nProduct achieved using Strassen's algorithm: ";
   for(i = 0; i < 2 ; i++) {
      cout<<endl;
      for(j = 0; j < 2; j++)
         cout<<z[i][j]<<" ";
   }
   return 0;
}

輸出

The first matrix is: 
12 34 
22 10 
The second matrix is: 
3 4 
2 1 
Product achieved using Strassen's algorithm: 
104 82 
86 98
public class Strassens {
   public static void main(String[] args) {
      int[][] x = {{12, 34}, {22, 10}};
      int[][] y = {{3, 4}, {2, 1}};
      int z[][] = new int[2][2];
      int m1, m2, m3, m4 , m5, m6, m7;
      System.out.print("The first matrix is: ");
      for(int i = 0; i<2; i++) {
         System.out.println();//new line
         for(int j = 0; j<2; j++) {
            System.out.print(x[i][j] + "\t");
         }
      }
      System.out.print("\nThe second matrix is: ");
      for(int i = 0; i<2; i++) {
         System.out.println();//new line
         for(int j = 0; j<2; j++) {
            System.out.print(y[i][j] + "\t");
         }
      }
      m1 = (x[0][0] + x[1][1]) * (y[0][0] + y[1][1]);
      m2 = (x[1][0] + x[1][1]) * y[0][0];
      m3 = x[0][0] * (y[0][1] - y[1][1]);
      m4 = x[1][1] * (y[1][0] - y[0][0]);
      m5 = (x[0][0] + x[0][1]) * y[1][1];
      m6 = (x[1][0] - x[0][0]) * (y[0][0]+y[0][1]);
      m7 = (x[0][1] - x[1][1]) * (y[1][0]+y[1][1]);
      z[0][0] = m1 + m4- m5 + m7;
      z[0][1] = m3 + m5;
      z[1][0] = m2 + m4;
      z[1][1] = m1 - m2 + m3 + m6;
      System.out.print("\nProduct achieved using Strassen's algorithm: ");
      for(int i = 0; i<2; i++) {
         System.out.println();//new line
         for(int j = 0; j<2; j++) {
            System.out.print(z[i][j] + "\t");
         }
      }
   }
}

輸出

The first matrix is: 
12	34	
22	10	
The second matrix is: 
3	4	
2	1	
Product achieved using Strassen's algorithm: 
104	82	
86	98	
import numpy as np
x = np.array([[12, 34], [22, 10]])
y = np.array([[3, 4], [2, 1]])
z = np.zeros((2, 2))
m1, m2, m3, m4, m5, m6, m7 = 0, 0, 0, 0, 0, 0, 0
print("The first matrix is: ")
for i in range(2):
    print()
    for j in range(2):
        print(x[i][j], end="\t")
print("\nThe second matrix is: ")
for i in range(2):
    print()
    for j in range(2):
        print(y[i][j], end="\t")
m1 = (x[0][0] + x[1][1]) * (y[0][0] + y[1][1])
m2 = (x[1][0] + x[1][1]) * y[0][0]
m3 = x[0][0] * (y[0][1] - y[1][1])
m4 = x[1][1] * (y[1][0] - y[0][0])
m5 = (x[0][0] + x[0][1]) * y[1][1]
m6 = (x[1][0] - x[0][0]) * (y[0][0] + y[0][1])
m7 = (x[0][1] - x[1][1]) * (y[1][0] + y[1][1])

z[0][0] = m1 + m4 - m5 + m7
z[0][1] = m3 + m5
z[1][0] = m2 + m4
z[1][1] = m1 - m2 + m3 + m6

print("\nProduct achieved using Strassen's algorithm: ")
for i in range(2):
    print()
    for j in range(2):
        print(z[i][j], end="\t")

輸出

The first matrix is: 

12	34	
22	10	
The second matrix is: 

3	4	
2	1	
Product achieved using Strassen's algorithm: 

104.0	82.0	
86.0	98.0
廣告