PyTorch – 五裁剪變換


為了將給定影像裁剪成四個角和中心裁剪,我們應用**FiveCrop()**變換。這是torchvision.transforms模組提供的眾多變換之一。此模組包含許多重要的變換,可用於對影像資料執行不同型別的操作。

**FiveCrop()**變換接受PIL影像和張量影像。張量影像是一個形狀為**[C, H, W]**的torch張量,其中C是通道數,H是影像高度,W是影像寬度。如果影像既不是PIL影像也不是張量影像,則我們首先將其轉換為張量影像,然後應用**FiveCrop**變換。

語法

torchvision.transforms.FiveCrop(size)

其中size是所需的裁剪大小。size是一個類似於**(h, w)**的序列,其中h和w分別是每個裁剪影像的高度和寬度。如果**size**是**int**,則裁剪的影像是正方形的。

它返回一個包含五個裁剪影像的元組,四個角影像和一箇中心影像。

步驟

我們可以使用以下步驟將影像裁剪成四個影像和一個給定大小的中心裁剪:

  • 匯入所需的庫。在以下所有示例中,所需的Python庫是**torch、Pillow**和**torchvision**。確保您已安裝它們。

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
  • 讀取輸入影像。輸入影像為PIL影像或torch張量。

img = Image.open('recording.jpg')
  • 定義一個變換,將影像裁剪成四個角和中心裁剪。對於矩形裁剪,裁剪大小設定為(150, 300),對於正方形裁剪,裁剪大小設定為250。根據您的需要更改裁剪大小。

# transform for rectangular crop
transform = transforms.FiveCrop((200,250))

# transform for square crop
transform = transforms.FiveCrop(250)
  • 將上面定義的變換應用於輸入影像,以將其裁剪成四個角和中心裁剪。

img = transform(img)
  • 顯示所有五個裁剪影像。

輸入影像

我們將在以下兩個示例中使用此影像。

示例1

在下面的Python3程式中,我們裁剪四個角和一箇中心裁剪。五個裁剪影像都是矩形的。

# import required libraries
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

# Read the image
img = Image.open('recording.jpg')

# define a transform to crop the image into four
# corners and the central crop
transform = transforms.FiveCrop((150, 300))

# apply the above transform on the image
imgs = transform(img)

# This transform returns a tuple of 5 images
print(type(imgs))
print("Total cropped images:",len(imgs))

輸出

<class 'tuple'>
Total cropped images: 5

示例2

在下面的Python3程式中,我們裁剪四個角和一箇中心裁剪。五個裁剪影像都是正方形的。

# import required libraries
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

# Read the image
img = Image.open('recording.jpg')

# define a transform to crop the image into four
# corners and the central crop
transform = transforms.FiveCrop(200)

# apply the above transform on the image
imgs = transform(img)

# Define a figure of size (8, 8)
fig=plt.figure(figsize=(8, 8))

# Define row and cols in the figure
rows, cols = 1, 5

# Display all 5 cropped images
for j in range(0, cols*rows):
   fig.add_subplot(rows, cols, j+1)
   plt.imshow(imgs[j])
   plt.xticks([])
   plt.yticks([])
plt.show()

輸出

它將產生以下輸出:

更新於:2022年1月6日

767 次瀏覽

啟動您的職業生涯

完成課程獲得認證

開始學習
廣告
© . All rights reserved.