如何在PyTorch中載入計算機視覺資料集?


PyTorch中有許多與計算機視覺任務相關的可用資料集。**torch.utils.data.Dataset** 提供不同型別的資料集。**torchvision.datasets** 是 **torch.utils.data.Dataset** 的子類,包含許多與影像和影片相關的資料集。PyTorch還提供了一個 **torch.utils.data.DataLoader**,用於從資料集中載入多個樣本。

步驟

我們可以使用以下步驟載入計算機視覺資料集:

  • 匯入所需的庫。在以下所有示例中,所需的Python庫為**torch**、**Matplotlib** 和 **torchvision**。確保您已安裝它們。

import torch
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
  • 我們使用 **datasets.CIFAR10()** 載入 CIFAR10 訓練和測試資料集,引數 **train=True** 用於訓練資料集,**train=False** 用於測試資料集。

root="data",
train=True,
download=True,
transform=ToTensor()
  • 定義訓練資料載入器 (**trainloader**) 和測試資料載入器 (**testloader**)。指定 **batch_size**。設定 **Shuffle=True** 以獲得隨機排列的影像。還可以訪問類標籤名稱。

  • 從訓練或測試資料集中獲取一些隨機影像和標籤。

dataiter = iter(trainloader)
images, labels = dataiter.next()
  • 使用標籤視覺化獲得的影像。

示例 1

在下面的 Python 程式中,我們載入 CIFAR10 訓練和測試資料集。

# Import the required libraries
import torch
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor

# define batch size
batch_size = 4

# download CIFAR10 training and test datasets
training_data = datasets.CIFAR10(
   root="data",
   train=True,
   download=True,
   transform=ToTensor()
)

test_data = datasets.CIFAR10(
   root="data",
   train=False,
   download=True,
   transform=ToTensor()
)

# define train and test dataloader
trainloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=2)

testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)

# access names of the labels
label_names = training_data.classes

# display details about the dataset
print("label_names:
", label_names) print("class label name to index:
", training_data.class_to_idx) print("Shape of training data:
", training_data.data.shape ) print("Shape of test data:
", test_data.data.shape )

輸出

Files already downloaded and verified
Files already downloaded and verified
label_names:
   ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog',
      'frog', 'horse', 'ship', 'truck']
class label name to index:
   {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer':
      4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
Shape of training data:
   (50000, 32, 32, 3)
Shape of test data:
   (10000, 32, 32, 3)

示例 2

在這個 Python 程式中,我們載入 CIFAR10 資料集。我們還視覺化一些帶有標籤名稱的隨機影像。

import torch
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

batch_size = 4
training_data = datasets.CIFAR10(
   root="data",
   train=True,
   download=True,
   transform=ToTensor()
)

test_data = datasets.CIFAR10(
   root="data",
   train=False,
   download=True,
   transform=ToTensor()
)

trainloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=False, num_workers=2)

testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)

label_names = training_data.classes

# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()

# display random images
# define figure
fig=plt.figure(figsize=(8, 5))
columns, rows = batch_size, 1

# visualize these random images
for i in range(1, columns*rows +1):
   fig.add_subplot(rows, columns, i)
   plt.imshow(images[i-1].numpy().transpose(1,2,0))
   plt.xticks([])
   plt.yticks([])
   plt.title(label_names[labels[i-1]])
plt.show()

輸出

Files already downloaded and verified
Files already downloaded and verified

更新於:2022年1月25日

725 次瀏覽

啟動您的職業生涯

透過完成課程獲得認證

開始學習
廣告
© . All rights reserved.