PyTorch – 如何計算矩陣的QR分解?


**torch.linalg.qr()** 計算矩陣或矩陣批次的 QR 分解。它接受浮點型、雙精度型、複數浮點型和複數雙精度型資料的矩陣和矩陣批次。

它返回一個命名元組 **(Q, R)。Q** 在矩陣為實數值時是正交的,在矩陣為複數值時是酉的。R 是一個上三角矩陣。

語法

(Q, R) = torch.linalg.qr(mat, mode='reduced')

引數

  • **Mat** – 方陣或方陣批次。

  • **mode** – 它決定 QR 分解的模式。它設定為三種模式之一:**'reduced'**、**'complete'** 和 **'r'**。預設為 'reduced'。這是一個可選引數。

步驟

  • 匯入所需的庫。在以下所有示例中,所需的 Python 庫是 **torch**。確保您已安裝它。

import torch
  • 建立一個矩陣或矩陣批次。這裡我們定義一個大小為 [3, 2] 的矩陣(一個 2D torch 張量)。

mat = torch.tensor([[1.,12.],[14.,5.],[17.,-8.]])
  • 使用 **torch.linalg.qr(mat)** 計算輸入矩陣或矩陣批次的 QR 分解。這裡 mat 是輸入矩陣。

Q, R = torch.linalg.qr(A)
  • 顯示 Q 和 R。

print("Q:
", Q) print("R:
", R)

示例 1

在這個 Python 程式中,我們計算矩陣的 QR 分解。我們沒有給出 mode 引數。它預設設定為 '**reduced**'。

# import necessary libraries
import torch

# create a matrix
mat = torch.tensor([[1.,12.],[14.,5.],[17.,-8.]])
print("Matrix:
", mat) # compute QR decomposition Q, R = torch.linalg.qr(mat) # print Q and S matrices print("Q:
",Q) print("R:
",R)

輸出

它將產生以下輸出:

Matrix:
   tensor([[ 1., 12.],
      [14., 5.],
      [17., -8.]])
Q:
   tensor([[-0.0454, 0.8038],
      [-0.6351, 0.4351],
      [-0.7711, -0.4056]])
R:
   tensor([[-22.0454, 2.4495],
      [ 0.0000, 15.0665]])

示例 2

在這個 Python 程式中,我們計算矩陣的 QR 分解。我們將 mode 設定為 'r'。

# import necessary libraries
import torch

# create a matrix
mat = torch.tensor([[1.,12.],[14.,5.],[17.,-8.]])
print("Matrix:
", mat) # compute QR decomposition Q, R = torch.linalg.qr(mat, mode = 'r') # print Q and S matrices print("Q:
",Q) print("R:
",R)

輸出

它將產生以下輸出:

Matrix:
   tensor([[ 1., 12.],
      [14., 5.],
      [17., -8.]])
Q:
   tensor([])
R:
   tensor([[-22.0454, 2.4495],
      [ 0.0000, 15.0665]])

示例 3

在這個 Python3 程式中,我們計算矩陣的 QR 分解。我們將 mode 設定為 'complete'。

# import necessary libraries
import torch

# create a matrix
mat = torch.tensor([[1.,12.],[14.,5.],[17.,-8.]])
print("Matrix:
", mat) # compute QR decomposition Q, R = torch.linalg.qr(mat, mode = 'complete') # print Q and S matrices print("Q:
", Q) print("R:
", R)

輸出

它將產生以下輸出:

Matrix:
   tensor([[ 1., 12.],
      [14., 5.],
      [17., -8.]])
Q:
   tensor([[-0.0454, 0.8038, 0.5931],
      [-0.6351, 0.4351, -0.6383],
      [-0.7711, -0.4056, 0.4907]])
R:
   tensor([[-22.0454, 2.4495],
      [ 0.0000, 15.0665],
      [ 0.0000, 0.0000]])

更新於:2022年1月7日

287 次瀏覽

啟動您的職業生涯

完成課程獲得認證

開始學習
廣告
© . All rights reserved.