PyTorch 中的“with torch no_grad”的作用是什麼?
“with torch.no_grad()” 的使用就像一個迴圈,其中迴圈內的每一個張量都將 requires_grad 設定為 False。這意味著當前與當前計算圖關聯的任何帶梯度的張量現在都從當前圖中分離出來。我們不再能夠計算關於此張量的梯度。
一個張量一直從當前圖中分離,直到它在迴圈中。一旦它脫離了迴圈,如果使用梯度定義了張量,就會再次將它附加到當前圖中。
我們來舉幾個例子,以更好地理解它是如何工作的。
示例 1
在這個示例中,我們建立了一個張量 x,其 requires_grad = true。接下來,我們定義這個張量 x 的函式 y,並將函式置於 torch.no_grad() 迴圈中。現在 x 在迴圈中,所以它的 requires_grad 被設定為 False。
在迴圈中,不能針對 x 計算 y 的梯度。所以,y.requires_grad 返回 False。
# import torch library import torch # define a torch tensor x = torch.tensor(2., requires_grad = True) print("x:", x) # define a function y with torch.no_grad(): y = x ** 2 print("y:", y) # check gradient for Y print("y.requires_grad:", y.requires_grad)
輸出
x: tensor(2., requires_grad=True) y: tensor(4.) y.requires_grad: False
示例 2
在此示例中,我們在迴圈外定義了函式 z。所以,z.requires_grad 返回 True。
# import torch library import torch # define three tensors x = torch.tensor(2., requires_grad = False) w = torch.tensor(3., requires_grad = True) b = torch.tensor(1., requires_grad = True) print("x:", x) print("w:", w) print("b:", b) # define a function y y = w * x + b print("y:", y) # define a function z with torch.no_grad(): z = w * x + b print("z:", z) # check if requires grad is true or not print("y.requires_grad:", y.requires_grad) print("z.requires_grad:", z.requires_grad)
輸出
x: tensor(2.) w: tensor(3., requires_grad=True) b: tensor(1., requires_grad=True) y: tensor(7., grad_fn=<AddBackward0>) z: tensor(7.) y.requires_grad: True z.requires_grad: False
廣告