PyTorch 中的“with torch no_grad”的作用是什麼?


“with torch.no_grad()” 的使用就像一個迴圈,其中迴圈內的每一個張量都將 requires_grad 設定為 False。這意味著當前與當前計算圖關聯的任何帶梯度的張量現在都從當前圖中分離出來。我們不再能夠計算關於此張量的梯度。

一個張量一直從當前圖中分離,直到它在迴圈中。一旦它脫離了迴圈,如果使用梯度定義了張量,就會再次將它附加到當前圖中。

我們來舉幾個例子,以更好地理解它是如何工作的。

示例 1

在這個示例中,我們建立了一個張量 x,其 requires_grad = true。接下來,我們定義這個張量 x 的函式 y,並將函式置於 torch.no_grad() 迴圈中。現在 x 在迴圈中,所以它的 requires_grad 被設定為 False

在迴圈中,不能針對 x 計算 y 的梯度。所以,y.requires_grad 返回 False

# import torch library
import torch

# define a torch tensor
x = torch.tensor(2., requires_grad = True)
print("x:", x)

# define a function y
with torch.no_grad():
   y = x ** 2
print("y:", y)

# check gradient for Y
print("y.requires_grad:", y.requires_grad)

輸出

x: tensor(2., requires_grad=True)
y: tensor(4.)
y.requires_grad: False

示例 2

在此示例中,我們在迴圈外定義了函式 z。所以,z.requires_grad 返回 True

# import torch library
import torch

# define three tensors
x = torch.tensor(2., requires_grad = False)
w = torch.tensor(3., requires_grad = True)
b = torch.tensor(1., requires_grad = True)

print("x:", x)
print("w:", w)
print("b:", b)

# define a function y
y = w * x + b
print("y:", y)

# define a function z
with torch.no_grad():
   z = w * x + b

print("z:", z)

# check if requires grad is true or not
print("y.requires_grad:", y.requires_grad)
print("z.requires_grad:", z.requires_grad)

輸出

x: tensor(2.)
w: tensor(3., requires_grad=True)
b: tensor(1., requires_grad=True)
y: tensor(7., grad_fn=<AddBackward0>)
z: tensor(7.)
y.requires_grad: True
z.requires_grad: False

更新日期:06-12-2021

6K+ 瀏覽

啟動你的 職業生涯

完成課程並獲得認證

開始
廣告