如何在PyTorch中對張量的元素進行排序?
為了對PyTorch中的張量元素進行排序,我們可以使用`torch.sort()`方法。此方法返回兩個張量。第一個張量是包含元素排序值的張量,第二個張量是原始張量中元素索引的張量。我們可以計算2D張量的行排序和列排序。
步驟
匯入所需的庫。在以下所有Python示例中,所需的Python庫是**torch**。確保您已安裝它。
建立一個PyTorch張量並列印它。
要對上面建立的張量的元素進行排序,請計算`**torch.sort(input, dim)**`。將此值賦給新變數`"v"`。這裡,`**input**`是輸入張量,`**dim**`是沿其對元素進行排序的維度。要對元素進行行排序,`dim`設定為1;要對元素進行列排序,`dim`設定為0。
包含排序值的張量可以訪問為`**v[0]**`,排序元素的索引張量可以訪問為`**v[1]**`。
列印包含排序值的張量和包含排序值索引的張量。
示例1
下面的Python程式展示瞭如何對一維張量的元素進行排序。
# Python program to sort elements of a tensor
# import necessary library
import torch
# Create a tensor
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)
# sort the tensor T
# it sorts the tensor in ascending order
v = torch.sort(T)
# print(v)
# print tensor of sorted value
print("Tensor with sorted value:\n", v[0])
# print indices of sorted value
print("Indices of sorted value:\n", v[1])輸出
Original Tensor: tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430]) Tensor with sorted value: tensor([-4.3300, -0.4330, 2.3340, 4.4330, 4.4430, 5.0000]) Indices of sorted value: tensor([2, 3, 0, 1, 5, 4])
示例2
下面的Python程式展示瞭如何對二維張量的元素進行排序。
# Python program to sort elements of a 2-D tensor
# import the library
import torch
# Create a 2-D tensor
T = torch.Tensor([[2,3,-32],
[43,4,-53],
[4,37,-4],
[3,-75,34]])
print("Original Tensor:\n", T)
# sort tensor T
# it sorts the tensor in ascending order
v = torch.sort(T)
# print(v)
# print tensor of sorted value
print("Tensor with sorted value:\n", v[0])
# print indices of sorted value
print("Indices of sorted value:\n", v[1])
print("Sort tensor Column-wise")
v = torch.sort(T, 0)
# print(v)
# print tensor of sorted value
print("Tensor with sorted value:\n", v[0])
# print indices of sorted value
print("Indices of sorted value:\n", v[1])
print("Sort tensor Row-wise")
v = torch.sort(T, 1)
# print(v)
# print tensor of sorted value
print("Tensor with sorted value:\n", v[0])
# print indices of sorted value
print("Indices of sorted value:\n", v[1])輸出
Original Tensor: tensor([[ 2., 3., -32.], [ 43., 4., -53.], [ 4., 37., -4.], [ 3., -75., 34.]]) Tensor with sorted value: tensor([[-32., 2., 3.], [-53., 4., 43.], [ -4., 4., 37.], [-75., 3., 34.]]) Indices of sorted value: tensor([[2, 0, 1], [2, 1, 0], [2, 0, 1], [1, 0, 2]]) Sort tensor Column-wise Tensor with sorted value: tensor([[ 2., -75., -53.], [ 3., 3., -32.], [ 4., 4., -4.], [ 43., 37., 34.]]) Indices of sorted value: tensor([[0, 3, 1], [3, 0, 0], [2, 1, 2], [1, 2, 3]]) Sort tensor Row-wise Tensor with sorted value: tensor([[-32., 2., 3.], [-53., 4., 43.], [ -4., 4., 37.], [-75., 3., 34.]]) Indices of sorted value: tensor([[2, 0, 1], [2, 1, 0], [2, 0, 1], [1, 0, 2]])
廣告
資料結構
網路
關係資料庫管理系統 (RDBMS)
作業系統
Java
iOS
HTML
CSS
Android
Python
C語言程式設計
C++
C#
MongoDB
MySQL
Javascript
PHP