如何在 PyTorch 中建立影像網格?


torchvision.utils 包提供了 make_grid() 函式來建立影像網格。影像應為 torch 張量。它接受形狀為 (B ☓ C ☓ H ☓ W) 的 4D 小批次張量或大小相同的張量影像列表。

  • 這裡,B 是批大小,C 是影像中通道的數量,HW 分別是高度和寬度。

  • 所有影像的 H ☓ W 應該相同。

此函式的輸出是一個包含影像網格的 torch 張量。我們可以使用 nrow 引數指定一行中的影像數量。我們還有許多其他引數來控制網格輸出。要視覺化影像網格,我們首先將整個網格轉換為 PIL 影像

語法

torchvision.utils.make_grid(tensor)

引數

  • tensor - 張量或張量列表。形狀為 (B x C x H x W) 的 4D 小批次張量或大小相同的影像列表。

輸出

它返回一個包含影像網格的 torch 張量。

步驟

  • 匯入所需的庫。在以下所有示例中,所需的 Python 庫為 torchtorchvision。確保您已安裝它們。

import torch
import torchvision
from torchvision.io import read_image
from torchvision.utils import make_grid
  • 使用 image_read() 函式讀取多個 JPEGPNG 影像。使用影像型別(.jpg.png)指定完整影像路徑。此函式的輸出是一個大小為 [image_channels, image_height, image_width] 的 torch 張量。

img1 = read_image('elephant.jpg')
img2 = read_image('cat.jpg')
img3 = read_image('dog.jpg')
  • 使用 make_grid() 函式建立讀取為 torch 張量的輸入影像網格。指定 nrow 以在網格中每行顯示的影像數量。

grid = make_grid([img1, img2, img3], nrow=3)
  • 將網格張量轉換為 PIL 影像並顯示它。

img = torchvision.transforms.ToPILImage()(grid)
img.show()

示例 1

在此 Python 程式中,我們讀取三個輸入影像並建立這些影像的網格。

import torch
import torchvision
from torchvision.io import read_image
from torchvision.utils import make_grid

# read images
img1 = read_image('elephant.jpg')
img2 = read_image('cat.jpg')
img3 = read_image('dog.jpg')
print("size of img1:", img1.size())
print("size of img2:", img2.size())
print("size of img3:", img3.size())

# make grid
grid = make_grid([img1, img2, img3])
print("size of grid:", grid.size())

# print("grid:
", grid) img = torchvision.transforms.ToPILImage()(grid) img.show()

輸出

size of img1: torch.Size([3, 466, 700])
size of img2: torch.Size([3, 466, 700])
size of img3: torch.Size([3, 466, 700])
size of grid: torch.Size([3, 470, 2108])

示例 2

在以下 Python 程式中,我們讀取四個輸入影像並建立這些影像的網格。我們將 nrow 設定為 2,以便網格每行有兩個影像。

# Import the required library
import torch
import torchvision
from torchvision.io import read_image
from torchvision.utils import make_grid

# read images
img1 = read_image('elephant.jpg')

# img1 = read_image('car.jpg')
print("Size of image:",img1.size())
img2 = read_image('cat.jpg')
img3 = read_image('dog.jpg')
img4 = read_image('leopard.jpg')

# make grid
grid = make_grid([img1, img2, img3, img4], nrow = 2)
print("size of grid:", grid.size())

# print("grid:
", grid) img = torchvision.transforms.ToPILImage()(grid) img.show()

輸出

Size of image: torch.Size([3, 466, 700])
size of grid: torch.Size([3, 938, 1406])

更新於: 2022年1月20日

4K+ 閱讀量

啟動你的 職業生涯

透過完成課程獲得認證

開始學習
廣告