Python PyTorch 中的 torch.argmax() 方法


為了找到輸入張量中元素最大值的索引,我們可以使用 **torch.argmax()** 函式。它只返回索引,而不是元素值。如果輸入張量有多個最大值,則該函式將返回第一個最大元素的索引。我們可以應用 **torch.argmax()** 函式來計算張量跨維度最大值的索引。

語法

torch.argmax(input)

步驟

我們可以使用以下步驟來查詢輸入張量中所有元素最大值的索引:

  • 匯入所需的庫。在以下所有示例中,所需的 Python 庫為 **torch**。確保您已安裝它。

import torch
  • 定義輸入張量 **input**。

input = torch.randn(3,4)
  • 計算張量 **input** 中所有元素最大值的索引。

indices = torch.argmax(input)
  • 列印上面計算出的帶有索引的張量。

print("Indices:
", indices)

示例 1

# Import the required library
import torch

# define an input tensor
input = torch.tensor([0., -1., 2., 8.])

# print above defined tensor
print("Input Tensor:
", input) # Compute indices of the maximum value indices = torch.argmax(input) # print the indices print("Indices:
", indices)

輸出

Input Tensor:
   tensor([ 0., -1., 2., 8.])
Indices:
   tensor(3)

在上面的 Python 示例中,我們找到了輸入 1D 張量中元素最大值的索引。輸入張量中的最大值為 8,該元素的索引為 3。

示例 2

在這個程式中,我們計算了相對於不同矩陣範數的條件數。

# Import the required library
import torch

# define an input tensor
input = torch.randn(4,4)

# print above defined tensor
print("Input Tensor:
", input) # Compute indices of the maximum value indices = torch.argmax(input) # print the indices print("Indices:
", indices) # Compute indices of the maximum value in dim 0 indices = torch.argmax(input, dim=0) # print the indices print("Indices in dim 0:
", indices) # Compute indices of the maximum value in dim 1 indices = torch.argmax(input, dim=1) # print the indices print("Indices in dim 1:
", indices)

輸出

Input Tensor:
   tensor([[-1.6729, 1.2613, -1.2882, -0.8133],
   [ 0.9192, 0.9301, -0.2372, 0.0162],
   [-0.4669, 0.6604, -0.7982, 0.2621],
   [ 0.6436, 1.0328, 2.4573, 0.0606]])
Indices:
   tensor(14)
Indices in dim 0:
   tensor([1, 0, 3, 2])
Indices in dim 1:
   tensor([1, 1, 1, 2])

在上面的 Python 示例中,我們找到了輸入 2D 張量中元素最大值的索引,分別在不同的維度上。我們使用 **torch.randn()** 方法生成了輸入張量的元素,因此您可能會注意到獲取不同的輸入張量和索引。

更新於: 2022年1月27日

9K+ 瀏覽量

開啟你的 職業生涯

透過完成課程獲得認證

開始學習
廣告

© . All rights reserved.