如何在 PyTorch 中建立影像網格?
torchvision.utils 包提供了 make_grid() 函式來建立影像網格。影像應為 torch 張量。它接受形狀為 (B ☓ C ☓ H ☓ W) 的 4D 小批次張量或大小相同的張量影像列表。
這裡,B 是批大小,C 是影像中通道的數量,H 和 W 分別是高度和寬度。
所有影像的 H ☓ W 應該相同。
此函式的輸出是一個包含影像網格的 torch 張量。我們可以使用 nrow 引數指定一行中的影像數量。我們還有許多其他引數來控制網格輸出。要視覺化影像網格,我們首先將整個網格轉換為 PIL 影像。
語法
torchvision.utils.make_grid(tensor)
引數
tensor - 張量或張量列表。形狀為 (B x C x H x W) 的 4D 小批次張量或大小相同的影像列表。
輸出
它返回一個包含影像網格的 torch 張量。
步驟
匯入所需的庫。在以下所有示例中,所需的 Python 庫為 torch 和 torchvision。確保您已安裝它們。
import torch import torchvision from torchvision.io import read_image from torchvision.utils import make_grid
使用 image_read() 函式讀取多個 JPEG 或 PNG 影像。使用影像型別(.jpg 或 .png)指定完整影像路徑。此函式的輸出是一個大小為 [image_channels, image_height, image_width] 的 torch 張量。
img1 = read_image('elephant.jpg')
img2 = read_image('cat.jpg')
img3 = read_image('dog.jpg')使用 make_grid() 函式建立讀取為 torch 張量的輸入影像網格。指定 nrow 以在網格中每行顯示的影像數量。
grid = make_grid([img1, img2, img3], nrow=3)
將網格張量轉換為 PIL 影像並顯示它。
img = torchvision.transforms.ToPILImage()(grid) img.show()
示例 1
在此 Python 程式中,我們讀取三個輸入影像並建立這些影像的網格。
import torch
import torchvision
from torchvision.io import read_image
from torchvision.utils import make_grid
# read images
img1 = read_image('elephant.jpg')
img2 = read_image('cat.jpg')
img3 = read_image('dog.jpg')
print("size of img1:", img1.size())
print("size of img2:", img2.size())
print("size of img3:", img3.size())
# make grid
grid = make_grid([img1, img2, img3])
print("size of grid:", grid.size())
# print("grid:
", grid)
img = torchvision.transforms.ToPILImage()(grid)
img.show()輸出
size of img1: torch.Size([3, 466, 700]) size of img2: torch.Size([3, 466, 700]) size of img3: torch.Size([3, 466, 700]) size of grid: torch.Size([3, 470, 2108])

示例 2
在以下 Python 程式中,我們讀取四個輸入影像並建立這些影像的網格。我們將 nrow 設定為 2,以便網格每行有兩個影像。
# Import the required library
import torch
import torchvision
from torchvision.io import read_image
from torchvision.utils import make_grid
# read images
img1 = read_image('elephant.jpg')
# img1 = read_image('car.jpg')
print("Size of image:",img1.size())
img2 = read_image('cat.jpg')
img3 = read_image('dog.jpg')
img4 = read_image('leopard.jpg')
# make grid
grid = make_grid([img1, img2, img3, img4], nrow = 2)
print("size of grid:", grid.size())
# print("grid:
", grid)
img = torchvision.transforms.ToPILImage()(grid)
img.show()輸出
Size of image: torch.Size([3, 466, 700]) size of grid: torch.Size([3, 938, 1406])

廣告
資料結構
網路
關係資料庫管理系統
作業系統
Java
iOS
HTML
CSS
Android
Python
C 語言程式設計
C++
C#
MongoDB
MySQL
Javascript
PHP