如何在PyTorch中計算兩個張量的餘弦相似度?
為了計算兩個張量之間的餘弦相似度,我們使用**torch.nn**模組提供的**CosineSimilarity()**函式。它返回沿**dim**計算的餘弦相似度值。
**dim**是此函式的一個可選引數,沿其計算餘弦相似度。
對於一維張量,我們只能沿**dim=0**計算餘弦相似度。
對於二維張量,我們可以沿**dim=0**或**1**計算餘弦相似度。
為了計算餘弦相似度,兩個張量的尺寸必須相同。兩個張量必須是實數值的。餘弦相似度常用於文字分析中度量文件相似度。
語法
torch.nn.CosineSimilarity(dim=1)
預設的**dim**設定為1。但是,如果您測量**一維張量**之間的餘弦相似度,則我們將**dim**設定為0。
步驟
匯入所需的庫。在以下所有示例中,所需的Python庫是**torch**。確保您已安裝它。
import torch
建立兩個張量並列印它們。兩個張量必須是實數值的。
tensor1 = torch.randn(3,4) tensor2 = torch.randn(3,4)
定義一個沿維度**dim**測量餘弦相似度的方法。
cos = torch.nn.CosineSimilarity(dim=0)
使用上面定義的方法計算餘弦相似度。
output = cos(tensor1, tensor2)
列印計算出的包含餘弦相似度值的張量。
print("Cosine Similarity:",output)示例1
下面的Python程式計算兩個一維張量之間的**餘弦相似度**。
# Import the required library
import torch
# define two input tensors
tensor1 = torch.tensor([0.1, 0.3, 2.3, 0.45])
tensor2 = torch.tensor([0.13, 0.23, 2.33, 0.45])
# print above defined two tensors
print("Tensor 1:
", tensor1)
print("Tensor 2:
", tensor2)
# define a method to measure cosine similarity
cos = torch.nn.CosineSimilarity(dim=0)
output = cos(tensor1, tensor2)
# display the output tensor
print("Cosine Similarity:",output)輸出
Tensor 1: tensor([0.1000, 0.3000, 2.3000, 0.4500]) Tensor 2: tensor([0.1300, 0.2300, 2.3300, 0.4500]) Cosine Similarity: tensor(0.9995)
示例2
在這個Python程式中,我們沿不同的**dim**計算兩個二維張量之間的餘弦相似度。
# Import the required library
import torch
# define two input tensors
tensor1 = torch.randn(3,4)
tensor2 = torch.randn(3,4)
# print above defined two tensors
print("Tensor 1:
", tensor1)
print("Tensor 2:
", tensor2)
# define a method to measure cosine similarity in dim 0
cos0 = torch.nn.CosineSimilarity(dim=0)
output0 = cos0(tensor1, tensor2)
print("Cosine Similarity in dim 0:
",output0)
# define a method to measure cosine similarity in dim 1
cos1 = torch.nn.CosineSimilarity(dim=1)
output1 = cos1(tensor1, tensor2)
print("Cosine Similarity in dim 1:
",output1)輸出
Tensor 1: tensor([[ 0.2714, 1.1430, 1.3997, 0.8788], [-2.2268, 1.9799, 1.5682, 0.5850], [ 1.2289, 0.5043, -0.1625, 1.1403]]) Tensor 2: tensor([[-0.3299, 0.6360, -0.2014, 0.5989], [-0.6679, 0.0793, -2.5842, -1.5123], [ 1.1110, -0.1212, 0.0324, 1.1277]]) Cosine Similarity in dim 0: tensor([ 0.8076, 0.5388, -0.7941, 0.3016]) Cosine Similarity in dim 1: tensor([ 0.4553, -0.3140, 0.9258])
廣告
資料結構
網路
關係型資料庫管理系統 (RDBMS)
作業系統
Java
iOS
HTML
CSS
Android
Python
C語言程式設計
C++
C#
MongoDB
MySQL
Javascript
PHP