PyTorch – 如何使用均值和標準差對影像進行歸一化?
**Normalize()** 變換使用均值和標準差對影像進行歸一化。**torchvision.transforms** 模組提供了許多重要的變換,可用於對影像資料執行不同型別的操作。
**Normalize()** 僅接受任何大小的張量影像。張量影像是一個 torch 張量。張量影像可能具有 n 個通道。**Normalize()** 變換對每個通道的張量影像進行歸一化。
由於此變換僅支援張量影像,因此應先將 PIL 影像轉換為 torch 張量。應用 **Normalize()** 變換後,我們將歸一化的 torch 張量轉換為 PIL 影像。
步驟
我們可以使用以下步驟來使用均值和標準差對影像進行歸一化:
匯入所需的庫。在以下所有示例中,所需的 Python 庫為 **torch、Pillow** 和 **torchvision**。請確保您已安裝它們。
import torch import torchvision import torchvision.transforms as T from PIL import Image
讀取輸入影像。輸入影像可以是 PIL 影像或 torch 張量。如果輸入影像為 PIL 影像,請將其轉換為 torch 張量。
img = Image.open('sunset.jpg') # convert image to torch tensor imgTensor = T.ToTensor()(img)
定義一個變換,使用均值和標準差對影像進行歸一化。這裡,我們使用 ImageNet 資料集的均值和標準差。
transform = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
將上面定義的變換應用於輸入影像以對影像進行歸一化。
normalized_imgTensor = transform(imgTensor)
將歸一化的張量影像轉換為 PIL 影像。
normalized_img = T.ToPILImage()(normalized_imgTensor)
顯示歸一化的影像。
normalized _img.show()
輸入影像
此影像用作以下所有示例中的輸入檔案。
示例 1
以下 Python 程式將輸入影像歸一化到均值和標準差。我們使用 ImageNet 資料集的均值和標準差。
# import required libraries import torch import torchvision.transforms as T from PIL import Image # Read the input image img = Image.open('sunset.jpg') # convert image to torch tensor imgTensor = T.ToTensor()(img) # define a transform to normalize the tensor transform = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # normalize the converted tensor using above defined transform normalized_imgTensor = transform(imgTensor) # convert the normalized tensor to PIL image normalized_img = T.ToPILImage()(normalized_imgTensor) # display the normalized PIL image normalized_img.show()
輸出
它將產生以下輸出:
示例 2
在此示例中,我們定義了一個 **Compose 變換** 來執行三個變換。
將 PIL 影像轉換為張量影像。
歸一化張量影像。
將歸一化的影像張量轉換為 PIL 影像。
# import required libraries import torch import torchvision.transforms as T from PIL import Image # read the input image img = Image.open('sunset.jpg') # define a transform to: # convert the PIL image to tensor # normalize the tensor # convert the tensor to PIL image transform = T.Compose([ T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), T.ToPILImage()]) # apply the above tensor on input image img = transform(img) img.show()
輸出
它將產生以下輸出:
廣告