PyTorch – 如何在隨機位置裁剪影像?


要在隨機位置裁剪影像,我們應用**RandomCrop()**變換。這是**torchvision.transforms**模組提供的眾多重要變換之一。

**RandomCrop()**變換接受PIL影像和張量影像。張量影像是一個形狀為**[C, H, W]**的torch張量,其中C是通道數,H是影像高度,W是影像寬度。

如果影像既不是PIL影像也不是張量影像,那麼我們首先將其轉換為張量影像,然後應用**RandomCrop()**。

語法

torchvision.transforms.RandomCrop(size)(img)

其中size是所需的裁剪大小。size是一個類似於**(h,w)**的序列,其中**h**和**w**分別是裁剪影像的高度和寬度。如果**size**是**int**,則裁剪後的影像將是正方形影像。

它返回在給定大小的隨機位置裁剪的影像。

步驟

我們可以使用以下步驟在給定大小的隨機位置裁剪影像:

  • 匯入所需的庫。在以下所有示例中,所需的Python庫是**torch、Pillow**和**torchvision**。確保您已安裝它們。

import torch
import torchvision
import torchvision.transforms as T
from PIL import Image
  • 讀取輸入影像。輸入影像為PIL影像或torch張量。

img = Image.open('meteor.jpg')
  • 定義一個變換,以便在隨機位置裁剪影像。矩形裁剪的裁剪大小為(200,250),正方形裁剪的裁剪大小為250。根據您的需要更改裁剪大小。

# transform for rectangular crop
transform = T.RandomCrop((200,250))
# transform for square crop
transform = T.RandomCrop(250)
  • 將上述定義的變換應用於輸入影像,以便在隨機位置裁剪影像。

img = transform(img)
  • 視覺化裁剪後的影像。

img.show()

輸入影像

此影像用作以下所有示例中的輸入。

示例1

以下Python 3程式顯示瞭如何在隨機位置裁剪輸入PIL影像。

# import required libraries
import torch
import torchvision.transforms as T
from PIL import Image

# read the input image
img = Image.open('meteor.png')

# define transform to crop the image at
# random location
transform = T.RandomCrop((250,500))
img = transform(img)
img.show()

輸出

它將產生以下輸出:

示例2

import torch
import torchvision.transforms as T
from PIL import Image

img = Image.open('lena.jpg')
transform = T.RandomCrop((250,500), padding=50)
img = transform(img)
img.show()

輸出

它將產生以下輸出。請注意,填充是隨機的。

示例3

# import required libraries
import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt

# read the input image
img = Image.open('meteor.png')

# define the transform with crop size
transform = T.RandomCrop((100,150))

# crop four images
imgs = [transform(img) for _ in range(4)]

# display these cropped images
fig = plt.figure(figsize=(7,3))
rows, cols = len(imgs),1
for j in range(0, len(imgs)):
   fig.add_subplot(rows, cols, j+1)
   plt.imshow(imgs[j])
   #plt.xticks([])
   #plt.yticks([])
plt.show()

輸出

它將產生以下輸出:

更新於:2022年1月6日

3K+ 次瀏覽

啟動你的職業生涯

完成課程獲得認證

開始學習
廣告
© . All rights reserved.