如何在PyTorch中計算輸入和目標張量之間的交叉熵損失?
為了計算輸入和目標(預測值和實際值)之間的交叉熵損失,我們應用函式 **CrossEntropyLoss()**。它可以從 **torch.nn** 模組訪問。它建立一個衡量交叉熵損失的標準。它是 **torch.nn** 模組提供的損失函式的一種型別。
損失函式用於透過最小化損失來最佳化深度神經網路。**CrossEntropyLoss()** 在訓練多類分類問題中非常有用。輸入預計包含每個類別的未歸一化分數。
目標張量可能包含範圍在 **[0,C-1]** 內的類索引,其中 **C** 是類的數量或類機率。
語法
torch.nn.CrossEntropyLoss()
步驟
為了計算交叉熵損失,可以遵循以下步驟
匯入所需的庫。在以下所有示例中,所需的 Python 庫是 **torch**。確保您已安裝它。
import torch
建立輸入和目標張量並列印它們。
input = torch.rand(3, 5) target = torch.empty(3, dtype = torch.long).random_(5)
建立一個標準來衡量交叉熵損失。
loss = nn.CrossEntropyLoss()
計算交叉熵損失並列印它。
output = loss(input, target)
print('Cross Entropy Loss:
', output)**注意** - 在以下示例中,我們使用隨機數來生成輸入和目標張量。因此,您可能會注意到這些張量的值不同。
示例 1
在這個示例中,我們計算輸入和目標張量之間的交叉熵損失。這裡我們以具有類索引的目標張量為例。
# Example of target with class indices
import torch
import torch.nn as nn
input = torch.rand(3, 5)
target = torch.empty(3, dtype = torch.long).random_(5)
print(target)
loss = nn.CrossEntropyLoss()
output = loss(input, target)
print('input:
', input)
print('target:
', target)
print('Cross Entropy Loss:
', output)輸出
tensor([2, 0, 4]) input: tensor([[0.2228, 0.2523, 0.9712, 0.7887, 0.2820], [0.7778, 0.4144, 0.8693, 0.1355, 0.3706], [0.0823, 0.5392, 0.0542, 0.0153, 0.8475]]) target: tensor([2, 0, 4]) Cross Entropy Loss: tensor(1.2340)
示例 2
在這個示例中,我們計算輸入和目標張量之間的交叉熵損失。這裡我們以具有類機率的目標張量為例。
# Example of target with class probabilities
import torch
import torch.nn as nn
input = torch.rand(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
print(target.size())
loss = nn.CrossEntropyLoss()
output = loss(input, target)
output.backward()
print("Input:
",input)
print("Target:
",target)
print("Cross Entropy Loss:
",output)
print('Input grads:
', input.grad)輸出
torch.Size([3]) Input: tensor([[0.8671, 0.0189, 0.0042, 0.1619, 0.9805], [0.1054, 0.1519, 0.6359, 0.6112, 0.9417], [0.9968, 0.3285, 0.9185, 0.0315, 0.9592]], requires_grad=True) Target: tensor([1, 0, 4]) Cross Entropy Loss: tensor(1.8338, grad_fn=<NllLossBackward>) Input grads: tensor([[ 0.0962, -0.2921, 0.0406, 0.0475, 0.1078], [-0.2901, 0.0453, 0.0735, 0.0717, 0.0997], [ 0.0882, 0.0452, 0.0815, 0.0336, -0.2484]])
廣告
資料結構
網路
關係資料庫管理系統 (RDBMS)
作業系統
Java
iOS
HTML
CSS
Android
Python
C語言程式設計
C++
C#
MongoDB
MySQL
Javascript
PHP