如何在PyTorch中載入計算機視覺資料集?
PyTorch中有許多與計算機視覺任務相關的可用資料集。**torch.utils.data.Dataset** 提供不同型別的資料集。**torchvision.datasets** 是 **torch.utils.data.Dataset** 的子類,包含許多與影像和影片相關的資料集。PyTorch還提供了一個 **torch.utils.data.DataLoader**,用於從資料集中載入多個樣本。
步驟
我們可以使用以下步驟載入計算機視覺資料集:
匯入所需的庫。在以下所有示例中,所需的Python庫為**torch**、**Matplotlib** 和 **torchvision**。確保您已安裝它們。
import torch import torchvision from torchvision import datasets from torchvision.transforms import ToTensor import matplotlib.pyplot as plt
我們使用 **datasets.CIFAR10()** 載入 CIFAR10 訓練和測試資料集,引數 **train=True** 用於訓練資料集,**train=False** 用於測試資料集。
root="data", train=True, download=True, transform=ToTensor()
定義訓練資料載入器 (**trainloader**) 和測試資料載入器 (**testloader**)。指定 **batch_size**。設定 **Shuffle=True** 以獲得隨機排列的影像。還可以訪問類標籤名稱。
從訓練或測試資料集中獲取一些隨機影像和標籤。
dataiter = iter(trainloader) images, labels = dataiter.next()
使用標籤視覺化獲得的影像。
示例 1
在下面的 Python 程式中,我們載入 CIFAR10 訓練和測試資料集。
# Import the required libraries
import torch
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
# define batch size
batch_size = 4
# download CIFAR10 training and test datasets
training_data = datasets.CIFAR10(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.CIFAR10(
root="data",
train=False,
download=True,
transform=ToTensor()
)
# define train and test dataloader
trainloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)
# access names of the labels
label_names = training_data.classes
# display details about the dataset
print("label_names:
", label_names)
print("class label name to index:
", training_data.class_to_idx)
print("Shape of training data:
", training_data.data.shape )
print("Shape of test data:
", test_data.data.shape )輸出
Files already downloaded and verified
Files already downloaded and verified
label_names:
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog',
'frog', 'horse', 'ship', 'truck']
class label name to index:
{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer':
4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
Shape of training data:
(50000, 32, 32, 3)
Shape of test data:
(10000, 32, 32, 3)示例 2
在這個 Python 程式中,我們載入 CIFAR10 資料集。我們還視覺化一些帶有標籤名稱的隨機影像。
import torch import torchvision from torchvision import datasets from torchvision.transforms import ToTensor import matplotlib.pyplot as plt batch_size = 4 training_data = datasets.CIFAR10( root="data", train=True, download=True, transform=ToTensor() ) test_data = datasets.CIFAR10( root="data", train=False, download=True, transform=ToTensor() ) trainloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=False, num_workers=2) testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2) label_names = training_data.classes # get some random training images dataiter = iter(trainloader) images, labels = dataiter.next() # display random images # define figure fig=plt.figure(figsize=(8, 5)) columns, rows = batch_size, 1 # visualize these random images for i in range(1, columns*rows +1): fig.add_subplot(rows, columns, i) plt.imshow(images[i-1].numpy().transpose(1,2,0)) plt.xticks([]) plt.yticks([]) plt.title(label_names[labels[i-1]]) plt.show()
輸出
Files already downloaded and verified Files already downloaded and verified

廣告
資料結構
網路
關係資料庫管理系統 (RDBMS)
作業系統
Java
iOS
HTML
CSS
Android
Python
C語言程式設計
C++
C#
MongoDB
MySQL
Javascript
PHP