如何正確訪問三維 Pytorch 張量中的元素?


PyTorch 是一個流行的開源機器學習框架,它在 CPU 和 GPU 上都提供了高效的張量運算。在 PyTorch 中,張量是一個多維陣列,它是用於儲存和操作資料的基本資料結構。

在此背景下,三維張量是一個具有三個維度的張量,它可以表示為一個類似立方體的結構,具有行、列和深度。要訪問三維 PyTorch 張量中的元素,您需要知道它的維度以及要訪問的元素的索引。

張量的索引使用方括號 ([]) 指定,您可以使用一個或多個以逗號分隔的索引來訪問張量中的元素。索引值從 0 開始,最後一個索引值總是小於該維度的大小。

現在我們已經從理論上了解了如何訪問三維張量中的元素,讓我們來看一些例子。

示例 1

訪問三維張量中的特定元素。

考慮以下程式碼。

import torch

# create a 3D tensor with dimensions 2x3x4
tensor_3d = torch.tensor([
    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
    [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
])

# access the element at row 1, column 2, and depth 3
element = tensor_3d[1, 2, 3]

# print the element
print(element)

解釋

  • 我們首先建立一個維度為 2x3x4 的三維張量,並用一些值初始化它。

  • 然後我們使用方括號訪問第 1 行、第 2 列和深度 3 的元素。

  • 最後,我們列印元素的值,它是 20。

輸出

20

示例 2

從三維張量中提取子張量

考慮以下程式碼。

import torch

# create a 3D tensor with dimensions 2x3x4
tensor_3d = torch.tensor([
    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
    [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
])

# extract a sub-tensor starting at row 0, column 1, and depth 1
sub_tensor = tensor_3d[:, 1:, 1:]

# print the sub-tensor
print(sub_tensor)

解釋

  • 我們首先建立一個維度為 2x3x4 的三維張量,並用一些值初始化它。

  • 然後我們使用切片從第 0 行、第 1 列和深度 1 開始提取子張量。

  • 子張量包括從第 0 行到末尾的所有元素、從第 1 列到末尾的所有元素以及從深度 1 到末尾的所有元素。

  • 最後,我們列印子張量,其中包含值 6、7、8、10、11、12、18、19、20、22、23 和 24。

輸出

tensor([[[ 6,  7,  8],
         [10, 11, 12]],

        [[18, 19, 20],
         [22, 23, 24]]])

示例 3

使用布林掩碼訪問三維張量中的特定元素

考慮以下程式碼。

import torch

# create a 3D tensor with dimensions 2x3x4
tensor_3d = torch.tensor([
    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
    [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
])

# create a boolean mask with the same dimensions as the tensor
mask = tensor_3d % 2 == 0

# use the mask to access specific elements in the tensor
even_elements = tensor_3d[mask]

# print the even elements
print(even_elements)

解釋

  • 我們首先建立一個維度為 2x3x4 的三維張量,並用一些值初始化它。

  • 然後我們建立一個與張量維度相同的布林掩碼,如果張量中對應的元素為偶數,則值為 True,否則為 False。

  • 我們使用掩碼透過將掩碼作為索引傳遞給張量來訪問張量中的特定元素。這將返回一個包含張量中所有偶數元素的一維張量。

  • 最後,我們列印偶數元素,它們是 2、4、6、8、10、12、14、16、18、20、22 和 24。

輸出

tensor([ 2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24])

結論

總之,訪問三維 PyTorch 張量中的元素是在 PyTorch 中處理多維資料的重要技能。在本文中,我們學習瞭如何使用索引和切片訪問三維張量中的特定元素,以及如何使用布林掩碼根據條件選擇特定元素。在嘗試訪問元素之前,瞭解張量的形狀和要訪問的元素的位置非常重要。

更新於:2023年8月3日

705 次瀏覽

啟動您的 職業生涯

完成課程獲得認證

開始學習
廣告
© . All rights reserved.