PyTorch – torch.linalg.solve() 方法


為了解決具有唯一解的線性方程組,我們可以使用 **torch.linalg.solve()** 方法。此方法接受兩個引數:

  • 首先,係數矩陣 **A**,以及

  • 其次,右側張量 **b**。

其中 **A** 是一個方陣,b 是一個向量。如果 A 可逆,則解是唯一的。我們可以求解多個線性方程組。在這種情況下,A 是一批方陣,b 是一批向量。

語法

torch.linalg.solve(A, b)

引數

  • **A** – 方陣或方陣批次。它是線性方程組的係數矩陣。

  • **b** – 向量或向量批次。它是線性系統的右側張量。

它返回線性方程組解的張量。

**注意** – 此方法假設係數矩陣 A 是可逆的。如果它不可逆,則會引發執行時錯誤。

步驟

我們可以使用以下步驟來解決線性方程組。

  • 匯入所需的庫。在以下所有示例中,所需的 Python 庫為 **torch**。確保您已安裝它。

import torch
  • 為給定的線性方程組定義係數矩陣和右側張量。

A = torch.tensor([[2., 3.],[1., -2.]])
b = torch.tensor([3., 0.])
  • 使用 torch.linalg.solve(A,b) 計算唯一解。係數矩陣 A 必須可逆。

X = torch.linalg.solve(A, b)
  • 顯示解決方案。

print("Solution:
", X)
  • 檢查計算出的解是否正確。

print(torch.allclose(A @ X, b))
# True for correct solution

示例 1

請檢視以下示例:

# import required library
import torch

'''
Let's suppose our square system of linear equations is:
2x + 3y = 3
x - 2y = 0
'''

print("Linear equation:")
print("2x + 3y = 3")
print("x - 2y = 0")

# define the coefficient matrix A
A = torch.tensor([[2., 3.],[1., -2.]])
# define right hand side tensor b
b = torch.tensor([3., 0.])

# Solve the linear equation
X = torch.linalg.solve(A, b)

# print the solution of above linear equation
print("Solution:
", X) # check above solution to be true print(torch.allclose(A @ X, b))

輸出

它將產生以下輸出:

Linear equation:
2x + 3y = 3
x - 2y = 0
Solution:
   tensor([0.8571, 0.4286])
True

示例 2

讓我們再舉一個例子:

# import required library
import torch

# define the coefficient matrix A for a 3x3
# square system of linear equations
A = torch.randn(3,3)

# define right hand side tensor b
b = torch.randn(3)

# Solve the linear equation
X = torch.linalg.solve(A, b)

# print the solution of above linear equation
print("Solution:
", X) # check above solution to be true print(torch.allclose(A @ X, b))

輸出

它將產生以下輸出:

Solution:
   tensor([-0.2867, -0.9850, 0.9938])
True

更新於: 2022年1月7日

693 次瀏覽

啟動您的 職業生涯

透過完成課程獲得認證

開始學習
廣告