如何在PyTorch中計算兩個向量的成對距離?
PyTorch中的向量是一維張量。為了計算兩個向量之間的成對距離,我們可以使用**PairwiseDistance()**函式。它使用**p-範數**來計算成對距離。**PairwiseDistance**基本上是**torch.nn**模組提供的類。
兩個向量的尺寸必須相同。
可以針對實值和復值輸入計算成對距離。
向量必須為**[N,D]**形狀,其中**N**是批次維度,**D**是向量維度。
語法
torch.nn.PairwiseDistance(p=2)
預設的**p**設定為2。
步驟
您可以使用以下步驟來計算兩個向量之間的成對距離
匯入所需的庫。在以下所有示例中,所需的Python庫是**torch**。確保您已經安裝了它。
import torch
定義兩個向量或兩批向量並列印它們。您可以定義實值或復值張量。
v1 = torch.randn(3,4) v2 = torch.randn(3,4)
建立一個**PairwiseDistance**例項來計算兩個向量之間的成對距離。
pdist = torch.nn.PairwiseDistance(p=2)
計算上面定義的向量之間的成對距離。
output = pdist (v1, v2)
列印包含成對距離值的計算張量。
print("Pairwise Distance:", output)示例1
在這個程式中,我們計算兩個一維向量之間的成對距離。請注意,我們已經對向量進行了unsqueeze操作以使其成為批處理。
# python3 program to compute pairwise distance between
# the two 1D vectors/ tensors
import torch
# define first vector
v1 = torch.tensor([1.,2.,3.,4.]) # size is [4]
# unsqueeze v1 to make it of size [1,4]
v1 = torch.unsqueeze(v1,0)
print("Size of v1:",v1.size())
# define and unsqueeze second vector
v2 = torch.tensor([2.,3.,4.,5.])
v2 = torch.unsqueeze(v2, 0)
print("Size of v2:",v2.size())
print("Vector v1:
", v1)
print("Vector v2:
", v2)
# create an instance of the PairwiseDistance
pdist = torch.nn.PairwiseDistance(p=2)
# compute the distance
output = pdist(v1, v2)
# display the distance
print("Pairwise Distance:
",output)輸出
Size of v1: torch.Size([1, 4]) Size of v2: torch.Size([1, 4]) Vector v1: tensor([[1., 2., 3., 4.]]) Vector v2: tensor([[2., 3., 4., 5.]]) Pairwise Distance: tensor([2.0000])
示例2
在這個程式中,我們計算兩批一維向量之間的成對距離。
# python3 program to compute pairwise distance between
# a batch vectors/ tensors
import torch
# define first batch of 3 vectors
v1 = torch.rand(3,4)
print("Size of v1:",v1.size())
# define second batch of 3 vectors
v2 = torch.rand(3,4)
print("Size of v2:",v2.size())
print("Vector v1:
", v1)
print("Vector v2:
", v2)
# define function to compute pairwise distance
pdist = torch.nn.PairwiseDistance(p=2)
# compute the distance
output = pdist(v1, v2)
# display the distances
print("Pairwise Distance:
",output)輸出
Size of v1: torch.Size([3, 4]) Size of v2: torch.Size([3, 4]) Vector v1: tensor([[0.7245, 0.7953, 0.6502, 0.9976], [0.1185, 0.6365, 0.3543, 0.3417], [0.7827, 0.3520, 0.5634, 0.0534]]) Vector v2: tensor([[0.6419, 0.2966, 0.4424, 0.6570], [0.5991, 0.4173, 0.5387, 0.1531], [0.8377, 0.6622, 0.8260, 0.8249]]) Pairwise Distance: tensor([0.6441, 0.5904, 0.8737])
廣告
資料結構
網路
關係資料庫管理系統 (RDBMS)
作業系統
Java
iOS
HTML
CSS
Android
Python
C語言程式設計
C++
C#
MongoDB
MySQL
Javascript
PHP