如何在 PyTorch 中對張量擠壓和展開?


要擠壓一個張量,我們使用 **torch.squeeze()** 方法。它返回一個新張量,該張量包含輸入張量的所有維度,但會移除大小 1。例如,如果輸入張量的形狀為 (M ☓ 1 ☓ N ☓ 1 ☓ P),則擠壓後的張量形狀為 (M ☓ M ☓ P)。

要展開一個張量,我們使用 **torch.unsqueeze()** 方法。它返回一個新張量,在特定位置插入大小為 1 的新維度。

步驟

  • 匯入所需的庫。在以下所有 Python 示例中,所需的 Python 庫為 **torch**。確保你已經安裝它。

  • 建立一個張量並列印它。

  • 計算 **torch.squeeze(input)**。它將擠壓(移除)大小 1,並返回一個包含 **input** 張量所有其他維度的張量。

  • 計算 **torch.unsqueeze(input, dim)**。它在給定的 dim 處插入大小為 1 的新維度,並返回該張量。

  • 列印擠壓和/或展開的張量。

示例 1

# Python program to squeeze and unsqueeze a tensor
# import necessary library
import torch

# Create a tensor of all one
T = torch.ones(2,1,2) # size 2x1x2
print("Original Tensor T:\n", T )
print("Size of T:", T.size())

# Squeeze the dimension of the tensor
squeezed_T = torch.squeeze(T) # now size 2x2
print("Squeezed_T\n:", squeezed_T )
print("Size of Squeezed_T:", squeezed_T.size())

輸出

Original Tensor T:
tensor([[[1., 1.]],
         [[1., 1.]]])
Size of T: torch.Size([2, 1, 2])
Squeezed_T
: tensor([[1., 1.],
         [1., 1.]])
Size of Squeezed_T: torch.Size([2, 2])

示例 2

# Python program to squeeze and unsqueeze a tensor
# import necessary library
import torch

# create a tensor
T = torch.Tensor([1,2,3]) # size 3
print("Original Tensor T:\n", T )
print("Size of T:", T.size())

# Squeeze the tensor in dimension o or column dim
unsqueezed_T = torch.unsqueeze(T, dim = 0) # now size 1x3
print("Unsqueezed T\n:", unsqueezed_T )
print("Size of UnSqueezed T:", unsqueezed_T.size())

# Squeeze the tensor in dimension 1 or row dim
unsqueezed_T = torch.unsqueeze(T, dim = 1) # now size 3x1
print("Unsqueezed T\n:", unsqueezed_T )
print("Size of Unsqueezed T:", unsqueezed_T.size())

輸出

Original Tensor T:
   tensor([1., 2., 3.])
Size of T: torch.Size([3])
Unsqueezed T
: tensor([[1., 2., 3.]])
Size of UnSqueezed T: torch.Size([1, 3])
Unsqueezed T
: tensor([[1.],
         [2.],
         [3.]])
Size of Unsqueezed T: torch.Size([3, 1])

更新日期:2021-11-06

4K+ 次瀏覽

開啟你的 職業生涯

完成課程即可獲得認證

開始
廣告
© . All rights reserved.