如何在 PyTorch 中透過重塑輸入張量來展平它?
可以使用 **torch.flatten()** 方法透過重塑將張量展平為一維張量。此方法支援實值和復值輸入張量。它以 torch 張量作為輸入,並返回展平為一維的 torch 張量。
它有兩個可選引數,**start_dim** 和 **end_dim**。如果傳遞了這些引數,則僅展平從 start_dim 開始到 end_dim 結束的那些維度。
輸入張量中元素的順序不會改變。此函式可能會返回原始物件、檢視或副本。在以下示例中,我們涵蓋了使用和不使用 **start_dim** 和 **end_dim** 展平張量的所有方面。
語法
torch.flatten(input, star_dim=0, end_dim=-1)
引數
**input** - 要展平的 torch 張量。
**start_dim** - 要展平的第一個維度。這是一個可選引數。預設設定為 0。
**end_dim** - 要展平的最後一個維度。這是一個可選引數。預設設定為 -1。
步驟
匯入所需的庫。在以下所有示例中,所需的 Python 庫為 **torch**。確保您已安裝它。
import torch
建立一個 PyTorch 張量並列印該張量。
t = torch.tensor([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]])
print("Tensor:
", t)使用上面定義的語法展平上述張量,並可選地將值賦給一個新變數。
flatten_t = torch.flatten(t, start_dim=0, end_dim=1)
列印展平後的張量。
print("Flattened Tensor:
", flatten_t)示例 1
在此程式中,我們將張量展平為一維張量。我們還使用 **start_dim** 展平張量。
Import the required library
import torch
# define a torch tensor
t = torch.tensor([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]])
print("Tensor:
", t)
print("Size of Tensor:", t.size())
# flatten the above tensor using start_dims
flatten_t = torch.flatten(t)
flatten_t0 = torch.flatten(t, start_dim=0)
flatten_t1 = torch.flatten(t, start_dim=1)
flatten_t2 = torch.flatten(t, start_dim=2)
# print the flatten tensors
print("Flatten tensor:
", flatten_t)
print("Flatten tensor (start_dim=0):
", flatten_t0)
print("Flatten tensor (start_dim=1):
", flatten_t1)
print("Flatten tensor (start_dim=2):
", flatten_t2)輸出
Tensor: tensor([[[ 1, 2, 3], [ 4, 5, 6]], [[ 7, 8, 9], [10, 11, 12]]]) Size of Tensor: torch.Size([2, 2, 3]) Flatten tensor: tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) Flatten tensor (start_dim=0): tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) Flatten tensor (start_dim=1): tensor([[ 1, 2, 3, 4, 5, 6], [ 7, 8, 9, 10, 11, 12]]) Flatten tensor (start_dim=2): tensor([[[ 1, 2, 3], [ 4, 5, 6]], [[ 7, 8, 9], [10, 11, 12]]])
示例 2
在此程式中,我們將張量展平為一維張量。我們還使用 **end_dim** 展平張量。
import torch
t = torch.tensor([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]])
print("Tensor:
", t)
print("Size of Tensor:", t.size())
# flatten the above tensor using end_dims
flatten_t = torch.flatten(t)
flatten_t0 = torch.flatten(t, end_dim=0)
flatten_t1 = torch.flatten(t, end_dim=1)
flatten_t2 = torch.flatten(t, end_dim=2)
# print the flatten tensors
print("Flatten tensor:
", flatten_t)
print("Flatten tensor (end_dim=0):
", flatten_t0)
print("Flatten tensor (end_dim=1):
", flatten_t1)
print("Flatten tensor (end_dim=2):
", flatten_t2)輸出
Tensor: tensor([[[ 1, 2, 3], [ 4, 5, 6]], [[ 7, 8, 9], [10, 11, 12]]]) Size of Tensor: torch.Size([2, 2, 3]) Flatten tensor: tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) Flatten tensor (end_dim=0): tensor([[[ 1, 2, 3], [ 4, 5, 6]], [[ 7, 8, 9], [10, 11, 12]]]) Flatten tensor (end_dim=1): tensor([[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9], [10, 11, 12]]) Flatten tensor (end_dim=2): tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
示例 3
在此程式中,我們將張量展平為一維張量。我們還使用 **start_dim** 和 **end_dim** 展平張量。
import torch
t = torch.empty(2,2,3,3).random_(30)
print("Tensor:
", t)
print("Size of Tensor:", t.size())
# flatten the above tensor using end_dims
flatten_t0 = torch.flatten(t, start_dim=2, end_dim=3)
# print the flatten tensors
print("Flatten tensor (start_dim=2,end_dim=3):
", flatten_t0)輸出
Tensor: tensor([[[[27., 13., 29.], [ 1., 23., 15.], [15., 7., 19.]], [[ 4., 14., 24.], [ 6., 4., 7.], [ 6., 18., 11.]]], [[[ 0., 27., 3.], [25., 12., 25.], [10., 23., 9.]], [[ 3., 1., 28.], [19., 7., 28.], [23., 14., 21.]]]]) Size of Tensor: torch.Size([2, 2, 3, 3]) Flatten tensor (start_dim=2,end_dim=3): tensor([[[27., 13., 29., 1., 23., 15., 15., 7., 19.], [ 4., 14., 24., 6., 4., 7., 6., 18., 11.]], [[ 0., 27., 3., 25., 12., 25., 10., 23., 9.], [ 3., 1., 28., 19., 7., 28., 23., 14., 21.]]])
廣告
資料結構
網路
關係資料庫管理系統
作業系統
Java
iOS
HTML
CSS
Android
Python
C 程式設計
C++
C#
MongoDB
MySQL
Javascript
PHP