如何在 PyTorch 中使用零填充輸入張量邊界?
**torch.nn.ZeroPad2D()** 使用零填充輸入張量邊界。它採用填充大小 (**padding**) 作為引數。填充大小可以是整數或元組。所有邊界或每個邊界的填充可能相同。
**padding** 可以是整數或 **(left, right, top, bottom)** 格式的元組。如果它是整數,則所有邊界的填充都相同。填充張量的**高度**增加 **top+bottom**,而填充張量的**寬度**增加 **left+right**。它不會更改通道大小或批次大小。填充通常在池化層之後在卷積神經網路 (CNN) 中使用,以保持輸入大小。
語法
torch.nn.ZeroPad2D(padding)
引數
**padding** – 期望的填充大小。整數或 (**left, right, top, bottom**) 格式的元組。
輸入張量的大小必須為 3D 或 4D,格式分別為 (**C, H, W**) 或 (**N, C, H, W**) ,其中 **N, C, H, W** 分別表示小批次大小、通道數、高度和寬度。
步驟
我們可以使用以下步驟用零填充輸入張量邊界:
匯入所需的庫。在以下所有示例中,所需的 Python 庫是 **torch**。確保您已安裝它。
import torch
定義輸入張量。我們定義一個 4D 張量如下。
input = torch.randn(2, 1, 3, 3)
定義填充大小並將其傳遞給 **torch.nn.ZeroPad2D()** 並建立一個例項 **pad** 以用零填充張量。填充大小可以相同或不同。
padding = (2,1) pad = nn.ZeroPad2d(padding)
使用上面建立的例項 **pad** 用零填充輸入張量。
output = pad(input)
列印最終填充的張量。
print("Padded Ternsor:
", output)示例 1
在下面的 Python 示例中,我們使用整數填充大小 2,即 **padding=2**,用零填充 3D 和 4D 張量。
# Import the required library
import torch
import torch.nn as nn
# define 3D tensor (C,H,W)
input = torch.tensor([[[ 1, 2],[ 3, 4]]])
print("Input Tensor:
",input)
# define padding same for all sides (left, right, top, bottom)
pad = nn.ZeroPad2d(2)
# pad the input tensor
output = pad(input)
print("Padded Ternsor:
", output)
# define 4D tensor (N,C,H,W)->for a batch of N tensors
input = torch.randn(2, 1, 3, 3)
print("Input Tensor:
",input)
# define padding same for all sides (left, right, top, bottom)
pad = nn.ZeroPad2d(2)
# pad the input tensor
output = pad(input)
print("Padded Tensor:
", output)輸出
Input Tensor: tensor([[[1, 2], [3, 4]]]) Padded Tensor: tensor([[[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 0, 0], [0, 0, 3, 4, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]]) Input Tensor: tensor([[[[-0.8336, -0.7609, 2.2278], [-0.5882, -1.2273, 0.3331], [ 2.1541, -0.0235, -0.4785]]], [[[ 0.8795, 2.6868, 1.2850], [-1.6687, -0.8479, 0.3797], [-1.5313, 0.5221, -1.5769]]]]) Padded Tensor: tensor([[[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, -0.8336, -0.7609, 2.2278, 0.0000, 0.0000], [ 0.0000, 0.0000, -0.5882, -1.2273, 0.3331, 0.0000, 0.0000], [ 0.0000, 0.0000, 2.1541, -0.0235, -0.4785, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]], [[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.8795, 2.6868, 1.2850, 0.0000, 0.0000], [ 0.0000, 0.0000, -1.6687, -0.8479, 0.3797, 0.0000, 0.0000], [ 0.0000, 0.0000, -1.5313, 0.5221, -1.5769, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
示例 2
在下面的 Python 示例中,我們使用對輸入張量所有邊界具有不同填充大小的填充大小來填充 3D 和 4D 張量。
# Import the required library
import torch
import torch.nn as nn
# define 3D tensor (C,H,W)
input = torch.tensor([[[ 1, 2],[ 3, 4]]])
print("Input Tensor:
",input)
# define padding different for different sides
padding = (2,1,2,1)
pad = nn.ZeroPad2d(padding)
# pad the input tensor
output = pad(input)
print("Padded Ternsor:
", output)
input = torch.tensor([[[ 1, 2],[ 3, 4]]])
print("Input Tensor:
",input)
# define padding different for left and right sides
padding = (2,1)
pad = nn.ZeroPad2d(padding)
# pad the input tensor
output = pad(input)
print("Padded Ternsor:
", output)
# define 4D tensor (N,C,H,W)->for a batch of N tensors
input = torch.tensor([[[ 1, 2],[ 3, 4]],[[ 1, 2],[ 3, 4]]])
print("Input Tensor:
",input)
# define padding different for different sides
padding = (2,2,1,1)
pad = nn.ZeroPad2d(padding)
# pad the input tensor
output = pad(input)
print("Padded Ternsor:
", output)輸出
Input Tensor: tensor([[[1, 2], [3, 4]]]) Padded Tensor: tensor([[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 1, 2, 0], [0, 0, 3, 4, 0], [0, 0, 0, 0, 0]]]) Input Tensor: tensor([[[1, 2], [3, 4]]]) Padded Ternsor: tensor([[[0, 0, 1, 2, 0], [0, 0, 3, 4, 0]]]) Input Tensor: tensor([[[1, 2], [3, 4]], [[1, 2], [3, 4]]]) Padded Tensor: tensor([[[0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 0, 0], [0, 0, 3, 4, 0, 0], [0, 0, 0, 0, 0, 0]], [[0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 0, 0], [0, 0, 3, 4, 0, 0], [0, 0, 0, 0, 0, 0]]])
資料結構
網路
關係資料庫管理系統 (RDBMS)
作業系統
Java
iOS
HTML
CSS
Android
Python
C語言程式設計
C++
C#
MongoDB
MySQL
Javascript
PHP