Python PyTorch 中的 torch.argmax() 方法
為了找到輸入張量中元素最大值的索引,我們可以使用 **torch.argmax()** 函式。它只返回索引,而不是元素值。如果輸入張量有多個最大值,則該函式將返回第一個最大元素的索引。我們可以應用 **torch.argmax()** 函式來計算張量跨維度最大值的索引。
語法
torch.argmax(input)
步驟
我們可以使用以下步驟來查詢輸入張量中所有元素最大值的索引:
匯入所需的庫。在以下所有示例中,所需的 Python 庫為 **torch**。確保您已安裝它。
import torch
定義輸入張量 **input**。
input = torch.randn(3,4)
計算張量 **input** 中所有元素最大值的索引。
indices = torch.argmax(input)
列印上面計算出的帶有索引的張量。
print("Indices:
", indices)示例 1
# Import the required library
import torch
# define an input tensor
input = torch.tensor([0., -1., 2., 8.])
# print above defined tensor
print("Input Tensor:
", input)
# Compute indices of the maximum value
indices = torch.argmax(input)
# print the indices
print("Indices:
", indices)輸出
Input Tensor: tensor([ 0., -1., 2., 8.]) Indices: tensor(3)
在上面的 Python 示例中,我們找到了輸入 1D 張量中元素最大值的索引。輸入張量中的最大值為 8,該元素的索引為 3。
示例 2
在這個程式中,我們計算了相對於不同矩陣範數的條件數。
# Import the required library
import torch
# define an input tensor
input = torch.randn(4,4)
# print above defined tensor
print("Input Tensor:
", input)
# Compute indices of the maximum value
indices = torch.argmax(input)
# print the indices
print("Indices:
", indices)
# Compute indices of the maximum value in dim 0
indices = torch.argmax(input, dim=0)
# print the indices
print("Indices in dim 0:
", indices)
# Compute indices of the maximum value in dim 1
indices = torch.argmax(input, dim=1)
# print the indices
print("Indices in dim 1:
", indices)輸出
Input Tensor: tensor([[-1.6729, 1.2613, -1.2882, -0.8133], [ 0.9192, 0.9301, -0.2372, 0.0162], [-0.4669, 0.6604, -0.7982, 0.2621], [ 0.6436, 1.0328, 2.4573, 0.0606]]) Indices: tensor(14) Indices in dim 0: tensor([1, 0, 3, 2]) Indices in dim 1: tensor([1, 1, 1, 2])
在上面的 Python 示例中,我們找到了輸入 2D 張量中元素最大值的索引,分別在不同的維度上。我們使用 **torch.randn()** 方法生成了輸入張量的元素,因此您可能會注意到獲取不同的輸入張量和索引。
廣告
資料結構
網路
關係型資料庫管理系統
作業系統
Java
iOS
HTML
CSS
Android
Python
C 語言程式設計
C++
C#
MongoDB
MySQL
Javascript
PHP