如何在PyTorch中計算輸入和目標張量之間的交叉熵損失?


為了計算輸入和目標(預測值和實際值)之間的交叉熵損失,我們應用函式 **CrossEntropyLoss()**。它可以從 **torch.nn** 模組訪問。它建立一個衡量交叉熵損失的標準。它是 **torch.nn** 模組提供的損失函式的一種型別。

損失函式用於透過最小化損失來最佳化深度神經網路。**CrossEntropyLoss()** 在訓練多類分類問題中非常有用。輸入預計包含每個類別的未歸一化分數。

目標張量可能包含範圍在 **[0,C-1]** 內的類索引,其中 **C** 是類的數量或類機率。

語法

torch.nn.CrossEntropyLoss()

步驟

為了計算交叉熵損失,可以遵循以下步驟

  • 匯入所需的庫。在以下所有示例中,所需的 Python 庫是 **torch**。確保您已安裝它。

import torch
  • 建立輸入和目標張量並列印它們。

input = torch.rand(3, 5)
target = torch.empty(3, dtype = torch.long).random_(5)
  • 建立一個標準來衡量交叉熵損失。

loss = nn.CrossEntropyLoss()
  • 計算交叉熵損失並列印它。

output = loss(input, target)
print('Cross Entropy Loss: 
', output)

**注意** - 在以下示例中,我們使用隨機數來生成輸入和目標張量。因此,您可能會注意到這些張量的值不同。

示例 1

在這個示例中,我們計算輸入和目標張量之間的交叉熵損失。這裡我們以具有類索引的目標張量為例。

# Example of target with class indices
import torch
import torch.nn as nn

input = torch.rand(3, 5)
target = torch.empty(3, dtype = torch.long).random_(5)
print(target)

loss = nn.CrossEntropyLoss()
output = loss(input, target)

print('input:
', input) print('target:
', target) print('Cross Entropy Loss:
', output)

輸出

tensor([2, 0, 4])
input:
   tensor([[0.2228, 0.2523, 0.9712, 0.7887, 0.2820],
      [0.7778, 0.4144, 0.8693, 0.1355, 0.3706],
      [0.0823, 0.5392, 0.0542, 0.0153, 0.8475]])
target:
   tensor([2, 0, 4])
Cross Entropy Loss:
   tensor(1.2340)

示例 2

在這個示例中,我們計算輸入和目標張量之間的交叉熵損失。這裡我們以具有類機率的目標張量為例。

# Example of target with class probabilities
import torch
import torch.nn as nn

input = torch.rand(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
print(target.size())
loss = nn.CrossEntropyLoss()
output = loss(input, target)
output.backward()
print("Input:
",input) print("Target:
",target) print("Cross Entropy Loss:
",output) print('Input grads:
', input.grad)

輸出

torch.Size([3])
Input:
   tensor([[0.8671, 0.0189, 0.0042, 0.1619, 0.9805],
      [0.1054, 0.1519, 0.6359, 0.6112, 0.9417],
      [0.9968, 0.3285, 0.9185, 0.0315, 0.9592]],
      requires_grad=True)
Target:
   tensor([1, 0, 4])
Cross Entropy Loss:
   tensor(1.8338, grad_fn=<NllLossBackward>)
Input grads:
   tensor([[ 0.0962, -0.2921, 0.0406, 0.0475, 0.1078],
      [-0.2901, 0.0453, 0.0735, 0.0717, 0.0997],
      [ 0.0882, 0.0452, 0.0815, 0.0336, -0.2484]])

更新於: 2022年1月20日

9K+ 次瀏覽

啟動您的 職業生涯

透過完成課程獲得認證

開始學習
廣告
© . All rights reserved.