PyTorch 中基於索引的操作


基於索引的操作在操作和訪問張量內特定元素或資料子集方面發揮著至關重要的作用。PyTorch 是一款流行的開源深度學習框架,它提供了強大的機制來高效地執行此類操作。透過利用基於索引的操作,開發人員可以沿張量的各個維度提取、修改和重新排列資料。

張量基礎

PyTorch 張量是多維陣列,可以儲存各種型別的數值資料,例如浮點數、整數或布林值。張量是 PyTorch 中的基本資料結構,是構建和操作神經網路的基礎。

要在 PyTorch 中建立張量,我們可以使用 torch.Tensor 類或 PyTorch 提供的各種工廠函式,例如 torch.zeros、torch.ones 或 torch.rand。讓我們來看幾個例子 

import torch
# Create a tensor of zeros with shape (3, 2)
zeros_tensor = torch.zeros(3, 2)
print(zeros_tensor)

# Create a tensor of ones with shape (2, 3)
ones_tensor = torch.ones(2, 3)
print(ones_tensor)

# Create a random tensor with shape (4, 4)
rand_tensor = torch.rand(4, 4)
print(rand_tensor)

除了張量的形狀之外,我們還可以使用 dtype 屬性檢查其資料型別。PyTorch 支援多種資料型別,包括 torch.float32、torch.float64、torch.int8、torch.int16、torch.int32、torch.int64 和 torch.bool。預設資料型別是 torch.float32。要指定特定的資料型別,我們可以在建立張量時傳遞 dtype 引數。

# Create a tensor of ones with shape (2, 2) and data type torch.float64
ones_double_tensor = torch.ones(2, 2, dtype=torch.float64)
print(ones_double_tensor)

除了從頭開始建立張量之外,我們還可以使用 torch.tensor 函式將現有的資料結構(例如列表或 NumPy 陣列)轉換為 PyTorch 張量。這允許與其他庫無縫整合,並簡化深度學習任務的資料準備工作。

import numpy as np

# Create a NumPy array
numpy_array = np.array([[1, 2, 3], [4, 5, 6]])

# Convert the NumPy array to a PyTorch tensor
tensor_from_numpy = torch.tensor(numpy_array)
print(tensor_from_numpy)

PyTorch 中的索引和切片

索引和切片操作在訪問 PyTorch 中張量的特定元素或子集方面起著至關重要的作用。它們允許我們有效地檢索和操作資料,使處理大型張量或提取有意義的資訊以進行進一步分析變得更容易。在本節中,我們將探討 PyTorch 中索引和切片的基礎知識。

基本索引

在 PyTorch 中,我們可以透過為每個維度提供索引來訪問張量的單個元素。索引從每個維度的第一個元素開始,為 0。讓我們來看一些例子 

import torch

# Create a tensor
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])

# Access the element at row 0, column 1
element = tensor[0, 1]
print(element)  # Output: tensor(2)

# Access the element at row 1, column 2
element = tensor[1, 2]
print(element)  # Output: tensor(6)

我們還可以使用負索引從維度的末尾訪問元素。例如,-1 指的是最後一個元素,-2 指的是倒數第二個元素,依此類推。

import torch

# Create a tensor
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])

# Access the last element
element = tensor[-1, -1]
print(element)  # Output: tensor(6)

切片

除了訪問單個元素之外,PyTorch 還支援切片操作來提取張量的子集。切片允許我們指定每個維度上的範圍或間隔,以一次檢索多個元素。讓我們看看切片是如何工作的 

import torch

# Create a tensor
tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Slice the first row
row_slice = tensor[0, :]
print(row_slice)  # Output: tensor([1, 2, 3])

# Slice the first column
column_slice = tensor[:, 0]
print(column_slice)  # Output: tensor([1, 4, 7])

# Slice a submatrix
submatrix_slice = tensor[1:, 1:]
print(submatrix_slice)  # Output: tensor([[5, 6], [8, 9]])

在上面的示例中,我們使用冒號 (:) 表示我們想要包含特定維度上的所有元素。這使我們能夠同時跨行、列或兩者進行切片。

使用整數和布林掩碼進行索引

除了常規索引和切片之外,PyTorch 還提供了使用整數陣列或布林掩碼的更高階的索引技術。這些技術提供了更大的靈活性和對我們想要訪問或修改的元素的控制。

我們可以使用整數陣列來指定我們想要從維度中選擇的索引。讓我們看一個例子 

import torch

# Create a tensor
tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Create an integer array of indices
indices = torch.tensor([0, 2])

# Select specific rows using integer array indexing
selected_rows = tensor[indices]
print(selected_rows)  # Output: tensor([[1, 2, 3], [7, 8, 9]])

高階索引技術

除了基本的索引和切片操作之外,PyTorch 還提供了高階索引技術,這些技術提供了更大的靈活性和對從張量中選擇元素的控制。在本節中,我們將探討這些技術以及如何在 PyTorch 中使用它們。

使用掩碼張量進行索引

PyTorch 中一個強大的索引技術涉及使用布林掩碼根據某些條件選擇元素。布林掩碼是一個與原始張量形狀相同的張量,其中每個元素都是 True 或 False,指示原始張量中的對應元素是否應該被選中。

讓我們看一個例子 –

import torch

# Create a tensor
tensor = torch.tensor([1, 2, 3, 4, 5])

# Create a boolean mask based on a condition
mask = tensor > 3

# Select elements based on the mask
selected_elements = tensor[mask]
print(selected_elements)  # Output: tensor([4, 5])

在這個例子中,我們透過應用條件 tensor > 3 建立了一個布林掩碼,它返回一個布林張量,指示 tensor 中的每個元素是否大於 3。然後我們使用這個掩碼來選擇 tensor 中僅滿足條件的元素,得到一個新的張量 [4, 5]。

省略號用於擴充套件切片

PyTorch 還提供了省略號 (...) 語法來執行擴充套件切片,這在處理更高維度的張量時特別有用。省略號允許我們在切片操作中表示多個冒號 (:),隱式地指示所有未明確提及的維度都包含在內。

讓我們考慮一個例子來說明它的用法 –

import torch

# Create a tensor of shape (2, 3, 4, 5)
tensor = torch.randn(2, 3, 4, 5)

# Use ellipsis for extended slicing
sliced_tensor = tensor[..., 1:3, :]
print(sliced_tensor.shape)  # Output: torch.Size([2, 3, 2, 5])

在這個例子中,省略號 ... 表示切片操作中未明確提及的所有維度。因此,tensor[..., 1:3, :] 從 tensor 中的所有維度選擇元素,除了第二個維度,它從第 1 個和第 2 個索引選擇元素。生成的切片張量形狀為 (2, 3, 2, 5)。

結論

PyTorch 中的基於索引的操作提供了一種靈活且有效的方式來訪問、修改和重新排列張量內的元素。透過利用基本索引、高階索引、布林索引和多維索引,開發人員可以輕鬆地執行細粒度的數 據操作、選擇和過濾任務。

更新於: 2023年8月14日

2K+ 閱讀量

開啟你的 職業生涯

透過完成課程獲得認證

開始學習
廣告
© . All rights reserved.