如何在PyTorch中測量目標和輸入機率之間的二元交叉熵?
我們應用 **BCELoss()** 方法來計算輸入和目標(預測和實際)機率之間的 *二元交叉熵* 損失。 **BCELoss()** 來自 **torch.nn** 模組。它建立一個衡量二元交叉熵損失的標準。它是 **torch.nn** 模組提供的損失函式的一種型別。
損失函式用於透過最小化損失來最佳化深度神經網路。輸入和目標都應該是具有類機率的張量。確保目標在 0 和 1 之間。輸入和目標張量都可以具有任意數量的維度。例如,在自動編碼器中,**BCELoss()** 用於測量重建誤差。
語法
torch.nn.BCELoss()
步驟
要計算二元交叉熵損失,可以按照以下步驟操作:
匯入所需的庫。在以下所有示例中,所需的Python庫是 **torch**。確保您已安裝它。
import torch
建立輸入和目標張量並列印它們。
input = torch.rand(3, 5) target = torch.randn(3, 5).softmax(dim=1)
建立一個標準來衡量二元交叉熵損失。
bce_loss = nn.BCELoss()
計算二元交叉熵損失並列印它。
output = bce_loss(input, target)
print('Binary Cross Entropy Loss:
', output)**注意** - 在以下示例中,我們使用隨機數來生成輸入和目標張量。因此,您可能會得到這些張量的不同值。
示例 1
在下面的Python程式中,我們計算輸入和目標機率之間的二元交叉熵損失。
import torch
import torch.nn as nn
input = torch.rand(6, requires_grad=True)
target = torch.rand(6)
# create a criterion to measure binary cross entropy
bce_loss = nn.BCELoss()
# compute the binary cross entropy
output = bce_loss(input, target)
output.backward()
print('input:
', input)
print('target:\ n ', target)
print('Binary Cross Entropy Loss:
', output)輸出
input: tensor([0.3440, 0.7944, 0.8919, 0.3551, 0.9817, 0.8871], requires_grad=True) target: tensor([0.1639, 0.4745, 0.1537, 0.5444, 0.6933, 0.1129]) Binary Cross Entropy Loss: tensor(1.2200, grad_fn=<BinaryCrossEntropyBackward>)
請注意,輸入和目標張量的元素都在 0 和 1 之間。
示例 2
在這個程式中,我們計算輸入和目標張量之間的BCE損失。兩個張量都是二維的。請注意,對於目標張量,我們使用 **softmax()** 函式使其元素在 0 和 1 之間。
import torch
import torch.nn as nn
input = torch.rand(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
loss = nn.BCELoss()
output = loss(input, target)
output.backward()
print("Input:
",input)
print("Target:
",target)
print("Binary Cross Entropy Loss:
",output)輸出
Input: tensor([[0.5080, 0.5674, 0.1960, 0.7617, 0.9675], [0.8497, 0.4167, 0.4464, 0.6646, 0.7448], [0.4477, 0.6700, 0.0358, 0.8317, 0.9484]], requires_grad=True) Target: tensor([[0.0821, 0.2900, 0.1864, 0.1480, 0.2935], [0.1719, 0.3426, 0.0729, 0.3616, 0.0510], [0.1284, 0.1542, 0.1338, 0.1779, 0.4057]]) Cross Entropy Loss: tensor(1.0689, grad_fn=<BinaryCrossEntropyBackward>)
請注意,輸入和目標張量的元素都在 0 和 1 之間。
廣告
資料結構
網路
關係資料庫管理系統(RDBMS)
作業系統
Java
iOS
HTML
CSS
Android
Python
C語言程式設計
C++
C#
MongoDB
MySQL
Javascript
PHP