如何在 PyTorch 中查詢張量的第 k 個元素和前 k 個元素?


PyTorch 提供了一個方法 **torch.kthvalue()** 來查詢張量的第 k 個元素。它返回按升序排序的張量中第 k 個元素的值,以及該元素在原始張量中的索引。

**torch.topk()** 方法用於查詢前 k 個元素。它返回張量中前 k 個或最大的 k 個元素。

步驟

  • 匯入所需的庫。在以下所有 Python 示例中,所需的 Python 庫是 **torch**。確保您已安裝它。

  • 建立一個 PyTorch 張量並列印它。

  • 計算 **torch.kthvalue(input, k)**。它返回兩個張量。將這兩個張量賦值給兩個新的變數 **"value"** 和 **"index"**。這裡,input 是一個張量,k 是一個整數。

  • 計算 **torch.topk(input, k)**。它返回兩個張量。第一個張量包含前 k 個元素的值,第二個張量包含這些元素在原始張量中的索引。將這兩個張量賦值給新的變數 **"values"** 和 **"indices"**。

  • 列印張量中第 k 個元素的值和索引,以及張量中前 k 個元素的值和索引。

示例 1

此 Python 程式演示如何查詢張量的第 k 個元素。

# Python program to find k-th element of a tensor
# import necessary library
import torch

# Create a 1D tensor
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)

# Find the 3rd element in sorted tensor. First it sorts the
# tensor in ascending order then returns the kth element value
# from sorted tensor and the index of element in original tensor
value, index = torch.kthvalue(T, 3)

# print 3rd element with value and index
print("3rd element value:", value)
print("3rd element index:", index)

輸出

Original Tensor:
   tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430])
3rd element value: tensor(2.3340)
3rd element index: tensor(0)

示例 2

以下 Python 程式演示如何查詢張量的前 k 個或最大的 k 個元素。

# Python program to find to top k elements of a tensor
# import necessary library
import torch

# Create a 1D tensor
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)

# Find the top k=2 or 2 largest elements of the tensor
# returns the 2 largest values and their indices in original
# tensor
values, indices = torch.topk(T, 2)

# print top 2 elements with value and index
print("Top 2 element values:", values)
print("Top 2 element indices:", indices)

輸出

Original Tensor:
   tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430])
Top 2 element values: tensor([5.0000, 4.4430])
Top 2 element indices: tensor([4, 5])

更新於:2021年11月6日

782 次瀏覽

啟動您的 職業生涯

完成課程獲得認證

開始
廣告
© . All rights reserved.