如何在 PyTorch 中連線張量?


我們可以使用 **torch.cat()** 和 **torch.stack()** 連線兩個或多個張量。**torch.cat()** 用於連線兩個或多個張量,而 **torch.stack()** 用於堆疊張量。我們可以在不同的維度上連線張量,例如 0 維度、-1 維度。

**torch.cat()** 和 **torch.stack()** 都用於連線張量。那麼,這兩種方法的基本區別是什麼呢?

  • **torch.cat()** 沿著現有維度連線一系列張量,因此不會改變張量的維度。

  • **torch.stack()** 沿著新維度堆疊張量,因此會增加維度。

步驟

  • 匯入所需的庫。在以下所有示例中,所需的 Python 庫為 **torch**。請確保您已安裝它。

  • 建立兩個或多個 PyTorch 張量並列印它們。

  • 使用 **torch.cat()** 或 **torch.stack()** 連線上面建立的張量。提供維度,例如 0、-1,以在特定維度上連線張量。

  • 最後,列印連線或堆疊的張量。

示例 1

# Python program to join tensors in PyTorch
# import necessary library
import torch

# create tensors
T1 = torch.Tensor([1,2,3,4])
T2 = torch.Tensor([0,3,4,1])
T3 = torch.Tensor([4,3,2,5])

# print above created tensors
print("T1:", T1)
print("T2:", T2)
print("T3:", T3)

# join (concatenate) above tensors using torch.cat()
T = torch.cat((T1,T2,T3))
# print final tensor after concatenation
print("T:",T)

輸出

執行以上 Python 3 程式碼時,將產生以下輸出

T1: tensor([1., 2., 3., 4.])
T2: tensor([0., 3., 4., 1.])
T3: tensor([4., 3., 2., 5.])
T: tensor([1., 2., 3., 4., 0., 3., 4., 1., 4., 3., 2., 5.])

示例 2

# import necessary library
import torch

# create tensors
T1 = torch.Tensor([[1,2],[3,4]])
T2 = torch.Tensor([[0,3],[4,1]])
T3 = torch.Tensor([[4,3],[2,5]])

# print above created tensors
print("T1:\n", T1)
print("T2:\n", T2)
print("T3:\n", T3)

print("join(concatenate) tensors in the 0 dimension")
T = torch.cat((T1,T2,T3), 0)
print("T:\n", T)

print("join(concatenate) tensors in the -1 dimension")
T = torch.cat((T1,T2,T3), -1)
print("T:\n", T)

輸出

執行以上 Python 3 程式碼時,將產生以下輸出

T1:
tensor([[1., 2.],
         [3., 4.]])
T2:
tensor([[0., 3.],
         [4., 1.]])
T3:
tensor([[4., 3.],
         [2., 5.]])
join(concatenate) tensors in the 0 dimension
T:
tensor([[1., 2.],
         [3., 4.],
         [0., 3.],
         [4., 1.],
         [4., 3.],
         [2., 5.]])
join(concatenate) tensors in the -1 dimension
T:
tensor([[1., 2., 0., 3., 4., 3.],
         [3., 4., 4., 1., 2., 5.]])

在以上示例中,2D 張量沿 0 和 -1 維度連線。沿 0 維度連線會增加行數,而列數保持不變。

示例 3

# Python program to join tensors in PyTorch
# import necessary library
import torch

# create tensors
T1 = torch.Tensor([1,2,3,4])
T2 = torch.Tensor([0,3,4,1])
T3 = torch.Tensor([4,3,2,5])

# print above created tensors
print("T1:", T1)
print("T2:", T2)
print("T3:", T3)

# join above tensor using "torch.stack()"
print("join(stack) tensors")
T = torch.stack((T1,T2,T3))

# print final tensor after join
print("T:\n",T)
print("join(stack) tensors in the 0 dimension")
T = torch.stack((T1,T2,T3), 0)

print("T:\n", T)
print("join(stack) tensors in the -1 dimension")
T = torch.stack((T1,T2,T3), -1)
print("T:\n", T)

輸出

執行以上 Python 3 程式碼時,將產生以下輸出

T1: tensor([1., 2., 3., 4.])
T2: tensor([0., 3., 4., 1.])
T3: tensor([4., 3., 2., 5.])
join(stack) tensors
T:
tensor([[1., 2., 3., 4.],
         [0., 3., 4., 1.],
         [4., 3., 2., 5.]])
join(stack) tensors in the 0 dimension
T:
tensor([[1., 2., 3., 4.],
         [0., 3., 4., 1.],
         [4., 3., 2., 5.]])
join(stack) tensors in the -1 dimension
T:
tensor([[1., 0., 4.],
         [2., 3., 3.],
         [3., 4., 2.],
         [4., 1., 5.]])

在以上示例中,您可以注意到 1D 張量被堆疊,最終張量為 2D 張量。

示例 4

# import necessary library
import torch

# create tensors
T1 = torch.Tensor([[1,2],[3,4]])
T2 = torch.Tensor([[0,3],[4,1]])
T3 = torch.Tensor([[4,3],[2,5]])

# print above created tensors
print("T1:\n", T1)
print("T2:\n", T2)
print("T3:\n", T3)

print("Join (stack)tensors in the 0 dimension")
T = torch.stack((T1,T2,T3), 0)
print("T:\n", T)

print("Join(stack) tensors in the -1 dimension")
T = torch.stack((T1,T2,T3), -1)
print("T:\n", T)

輸出

執行以上 Python 3 程式碼時,將產生以下輸出。

T1:
tensor([[1., 2.],
         [3., 4.]])
T2:
tensor([[0., 3.],
         [4., 1.]])
T3:
tensor([[4., 3.],
         [2., 5.]])
Join (stack)tensors in the 0 dimension
T:
tensor([[[1., 2.],
         [3., 4.]],
         [[0., 3.],
         [4., 1.]],
         [[4., 3.],
         [2., 5.]]])
Join(stack) tensors in the -1 dimension
T:
tensor([[[1., 0., 4.],
         [2., 3., 3.]],
         [[3., 4., 2.],
         [4., 1., 5.]]])

在以上示例中,您可以注意到 2D 張量被連線(堆疊)以建立 3D 張量。

更新於: 2023年9月14日

32K+ 瀏覽量

開啟您的 職業生涯

透過完成課程獲得認證

開始學習
廣告

© . All rights reserved.