如何在 PyTorch 中縮小張量?
torch.narrow() 方法用於對 PyTorch 張量執行縮小操作。它返回一個新的張量,它是原始輸入張量的縮小版本。
例如,一個 [4, 3] 的張量可以縮小到 [2, 3] 或 [4, 2] 大小的張量。我們可以一次沿著單個維度縮小張量。在這裡,我們不能將兩個維度都縮小到 [2, 2] 的大小。我們也可以使用 Tensor.narrow() 來縮小張量。
語法
torch.narrow(input, dim, start, length) Tensor.narrow(dim, start, length)
引數
input – 要縮小的 PyTorch 張量。
dim – 要沿其縮小原始張量 input 的維度。
Start – 開始維度。
Length – 從開始維度到結束維度的長度。
步驟
匯入 torch 庫。確保你已經安裝了它。
import torch
建立一個 PyTorch 張量並列印張量及其大小。
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print("Tensor:
", t)
print("Size of tensor:", t.size()) # size 3x3計算 torch.narrow(input, dim, start, length) 並將值賦給一個變數。
t1 = torch.narrow(t, 0, 1, 2)
縮小後,列印結果張量及其大小。
print("Tensor after Narrowing:
", t2)
print("Size after Narrowing:", t2.size())示例 1
在下面的 Python 程式碼中,輸入張量大小為 [3, 3]。我們使用 dim = 0,start = 1 和 length = 2 沿維度 0 縮小張量。它返回一個維度為 [2, 3] 的新張量。
請注意,新張量沿維度 0 縮小,並且沿維度 0 的長度更改為 2。
# import the library
import torch
# create a tensor
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# print the created tensor
print("Tensor:
", t)
print("Size of Tensor:", t.size())
# Narrow-down the tensor in dimension 0
t1 = torch.narrow(t, 0, 1, 2)
print("Tensor after Narrowing:
", t1)
print("Size after Narrowing:", t1.size())
# Narrow down the tensor in dimension 1
t2 = torch.narrow(t, 1, 1, 2)
print("Tensor after Narrowing:
", t2)
print("Size after Narrowing:", t2.size())輸出
Tensor:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
Size of Tensor: torch.Size([3, 3])
Tensor after Narrowing:
tensor([[4, 5, 6],
[7, 8, 9]])
Size after Narrowing: torch.Size([2, 3])
Tensor after Narrowing:
tensor([[2, 3],
[5, 6],
[8, 9]])
Size after Narrowing: torch.Size([3, 2])示例 2
以下程式演示瞭如何使用 Tensor.narrow() 實現縮小操作。
# import required library
import torch
# create a tensor
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
# print the above created tensor
print("Tensor:
", t)
print("Size of Tensor:", t.size())
# Narrow-down the tensor in dimension 0
t1 = t.narrow(0, 1, 2)
print("Tensor after Narrowing:
", t1)
print("Size after Narrowing:", t1.size())
# Narrow down the tensor in dimension 1
t2 = t.narrow(1, 0, 2)
print("Tensor after Narrowing:
", t2)
print("Size after Narrowing:", t2.size())輸出
Tensor:
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
Size of Tensor: torch.Size([4, 3])
Tensor after Narrowing:
tensor([[4, 5, 6],
[7, 8, 9]])
Size after Narrowing: torch.Size([2, 3])
Tensor after Narrowing:
tensor([[ 1, 2],
[ 4, 5],
[ 7, 8],
[10, 11]])
Size after Narrowing: torch.Size([4, 2])
廣告
資料結構
網路
關係資料庫管理系統 (RDBMS)
作業系統
Java
iOS
HTML
CSS
Android
Python
C 程式設計
C++
C#
MongoDB
MySQL
Javascript
PHP