PyTorch – torchvision.transforms – RandomResizedCrop()


RandomResizedCrop() 變換會裁剪原始輸入影像的隨機區域。此裁剪大小是隨機選擇的,最後裁剪後的影像將調整為給定大小。RandomResizedCrop() 變換是torchvision.transforms 模組提供的眾多變換之一。此模組包含許多重要的變換,可用於對影像資料執行不同型別的操作。

RandomResizedCrop() 接受 PIL 和張量影像。張量影像是一個 PyTorch 張量,形狀為[..., H, W],其中 ... 表示任意數量的維度,H 是影像高度,W 是影像寬度。如果影像既不是 PIL 影像也不是張量影像,則首先將其轉換為張量影像,然後應用變換。

語法

torchvision.transforms.RandomResizedCrop(size)(img)

其中 size 是所需的裁剪大小。size 是一個類似於(h, w)的序列,其中 h 和 w 分別是裁剪影像的高度和寬度。如果size 是一個int,則裁剪後的影像為正方形影像。

它返回使用給定大小調整大小的裁剪影像。

步驟

我們可以使用以下步驟裁剪輸入影像的隨機部分並將其調整為給定大小:

  • 匯入所需的庫。在以下所有示例中,所需的 Python 庫為torch、Pillowtorchvision。確保您已安裝它們。

import torch
import torchvision
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt
  • 讀取輸入影像。輸入影像為 PIL 影像或形狀為 [..., H, W] 的 torch 張量。

img = Image.open('baseball.png')
  • 定義一個變換,以裁剪輸入影像上的隨機部分,然後調整為給定大小。此處給定大小為 (150,250) 用於矩形裁剪,250 用於正方形裁剪。根據您的需要更改裁剪大小。

# transform for rectangular crop
transform = T.RandomResizedCrop((150,250))
# transform for square crop
transform = T.RandomResizedCrop(250)
  • 將上述定義的變換應用於輸入影像,以裁剪輸入影像上的隨機部分,然後將其調整為給定大小。

cropped_img = transform(img)
  • 顯示裁剪後的影像,然後顯示調整大小的影像

cropped_img.show()

輸入影像

此影像用作以下所有示例中的輸入。

示例 1

在此程式中,裁剪輸入影像的隨機部分,然後將其大小調整為 (150, 250)。

# import required libraries
import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt

# read the input image
img = Image.open('baseball.png')

# define a transform to crop a random portion of an image
# and resize it to given size
transform = T.RandomResizedCrop(size=(350,600))

# apply above defined transform to the input image
img = transform(img)

# display the cropped image
img.show()

輸出

它將產生以下輸出:

示例 2

import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt

img = Image.open('baseball.png')
transform = T.RandomResizedCrop(size = (200,150), scale=(0.08,
1.0), ratio=(0.75, 1.3333333333333333))
imgs = [transform(img) for _ in range(4)]
fig = plt.figure(figsize=(7,3))
rows, cols = 2,2
for j in range(0, len(imgs)):
   fig.add_subplot(rows, cols, j+1)
   plt.imshow(imgs[j])
   plt.xticks([])
   plt.yticks([])
plt.show()

輸出

它將產生以下輸出:

更新於: 2022-01-06

2K+ 次檢視

開啟你的 職業生涯

透過完成課程獲得認證

立即開始
廣告

© . All rights reserved.