在 Pytorch 中載入資料
每個機器學習專案都依賴於資料,由 Facebook 建立的著名開源機器學習工具包 PyTorch 也不例外。本手冊旨在簡化 PyTorch 中的資料載入過程,並幫助您儘快開始使用。
本文將重點介紹 PyTorch 的 DataLoader、Dataset 和 Transform 類。我們將透過一些實際示例來幫助您理解這些 PyTorch 核心概念,並簡化您的機器學習應用程式。
PyTorch 資料載入:簡要概述
為了匯入和準備資料,PyTorch 提供了一個強大且靈活的工具箱。三個關鍵要素是:
Dataset − 這個抽象類代表一個數據集,它允許以任何格式載入資料。只需要重寫兩個方法 __getitem__() 和 __len__()。
DataLoader − 它封裝了一個 Dataset,並提供對底層資料的快速訪問。它會自動構建批次、隨機打亂資料,並使用多執行緒並行載入資料。
Transforms − 這些是常見的影像修改。可以透過 Compose 將轉換連結在一起。這使您可以建立一個預處理操作管道,可以將其應用於載入的資料。
將資料載入到 PyTorch:示例
考慮一個影像集合,其中每個影像都表示為一個 3D NumPy 陣列,並且標籤與影像分開儲存。以下是如何將此資料新增到 PyTorch 的快速方法。
from torch.utils.data import Dataset, DataLoader
import numpy as np
class ImageDataset(Dataset):
def __init__(self, images, labels):
self.images = images
self.labels = labels
def __getitem__(self, index):
return self.images[index], self.labels[index]
def __len__(self):
return len(self.labels)
# Let's assume we have image data in NumPy arrays
images = np.random.rand(10000, 3, 32, 32)
labels = np.random.randint(0, 10, 10000)
dataset = ImageDataset(images, labels)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
我們在上述程式碼中建立了一個自定義的 Dataset 類。__len__ 函式返回影像的總數,而 __getitem__ 方法返回給定索引處的影像和標籤。然後,DataLoader 將包裝此 Dataset,它將處理批處理和資料隨機打亂。
在 PyTorch 中使用 Transforms
您可以使用轉換以靈活的方式預處理資料。例如,在基於影像的任務中,我們通常需要對資料進行歸一化、將其轉換為張量或使用資料增強技術。使用 PyTorch 的轉換模組,這些任務變得非常簡單。
from torchvision import transforms
# Define a transform to normalize the data
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Apply the transform to all images in the dataset
class ImageDataset(Dataset):
def __init__(self, images, labels, transform=None):
self.images = images
self.labels = labels
self.transform = transform
def __getitem__(self, index):
image = self.images[index]
if self.transform:
image = self.transform(image)
return image, self.labels[index]
def __len__(self):
return len(self.labels)
dataset = ImageDataset(images, labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
在此示例中,轉換在歸一化後將影像資料轉換為 PyTorch 張量。當我們例項化我們的 ImageDataset 時,我們將此轉換傳遞給它,然後它將在 '__getitem__' 方法中應用於每個影像。
從 CSV 檔案載入資料
對於諸如迴歸分析和分類之類的操作,通常需要載入來自 CSV 檔案的資料。讓我們使用 pandas 載入 CSV 檔案、處理資料並構建 PyTorch DataLoader。
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import TensorDataset
# Load the data from a CSV file
df = pd.read_csv('data.csv')
# Convert categorical data to numerical data
le = LabelEncoder()
df['category'] = le.fit_transform(df['category'])
# Split the data into inputs and targets
inputs = df.drop('category', axis=1).values
targets = df['category'].values
# Convert to PyTorch Dataset
dataset = TensorDataset(torch.from_numpy(inputs), torch.from_numpy(targets))
# Wrap in a DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
在此示例中,pandas 用於從 CSV 檔案載入資料。然後,Scikit-Learn 中的 LabelEncoder 函式用於將分類資料轉換為數值資料。輸入和目標被分割,它們被轉換為 PyTorch 張量,並建立了一個 TensorDataset。最後,我們建立了一個 DataLoader 來處理批處理和隨機打亂。
結論
在 PyTorch 中建立有效的機器學習模型,資料載入是一項基本技能。使用 PyTorch 的 DataLoader、Dataset 和 Transform 類,這項工作變得更簡單、更高效。無論您是在處理表格資料還是影像資料,都可以修改這些類以滿足您的需求。
資料結構
網路
關係型資料庫管理系統 (RDBMS)
作業系統
Java
iOS
HTML
CSS
Android
Python
C 程式設計
C++
C#
MongoDB
MySQL
Javascript
PHP