PyTorch 中的 backward() 有什麼作用?


backward() 方法用於在神經網路的反向傳播中計算梯度。

  • 執行此方法時將計算梯度。

  • 這些梯度將儲存在相應的變數中。

  • 梯度相對於這些變數計算,而梯度可透過 .grad 進行訪問。

  • 如果不呼叫 backward() 方法來計算梯度,則不會計算梯度。

  • 如果我們使用 .grad 訪問梯度,則結果為

我們舉幾個例子來說明它的工作原理。

示例 1

在此示例中,我們嘗試在不呼叫 backward() 方法的情況下訪問梯度。我們注意到所有的梯度都是

# import torch library
import torch

# define three tensor
x = torch.tensor(2., requires_grad = False)
w = torch.tensor(3., requires_grad = True)
b = torch.tensor(1., requires_grad = True)
print("x:", x)
print("w:", w)
print("b:", b)

# define a function of the above defined tensors
y = w * x + b
print("y:", y)

# print the gradient w.r.t above tensors
print("x.grad:", x.grad)
print("w.grad:", w.grad)
print("b.grad:", b.grad)

輸出

x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
x.grad: None
w.grad: None
b.grad: None

示例 2

在第二個示例中,呼叫了函式 ybackward() 方法。然後,訪問了梯度。對於不需要grad的張量,相對於它們的梯度仍然是。但對於需要梯度的張量,相對於它們的梯度並非無。

# import torch library
import torch

# define three tensors
x = torch.tensor(2., requires_grad = False)
w = torch.tensor(3., requires_grad = True)
b = torch.tensor(1., requires_grad = True)
print("x:", x)
print("w:", w)
print("b:", b)

# define a function y
y = w * x + b
print("y:", y)

# take the backward() for y
y.backward()
# print the gradients w.r.t. above x, w, and b
print("x.grad:", x.grad)
print("w.grad:", w.grad)
print("b.grad:", b.grad)

輸出

x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
x.grad: None
w.grad: tensor(2.)
b.grad: tensor(1.)

更新於: 06-12-2021

3K+ 次瀏覽

開始你的 職業生涯

完成課程獲得認證

開始學習
廣告
© . All rights reserved.