PyTorch中的Tensor.detach()的作用是什麼?


Tensor.detach()用於從當前計算圖中分離張量。它將返回一個不需要梯度的張量。

  • 當不需要跟蹤張量進行梯度計算時,我們會將張量從當前計算圖中分離出來。

  • 當我們需要將張量從GPU傳輸到CPU時,我們也需要分離一個張量。

語法

Tensor.detach()

它將返回一個新的張量,且requires_grad = True。將不再計算與此張量有關的梯度。

步驟

  • 匯入torch庫。確保你已安裝該庫。

import torch
  • 使用requires_grad = True建立PyTorch張量並列印張量。

x = torch.tensor(2.0, requires_grad = True)
print("x:", x)
  • 計算Tensor.detach()並選擇性地將此值賦給新變數。

x_detach = x.detach()
  • 在執行.detach()操作後列印張量。

print("Tensor with detach:", x_detach)

樣例1

# import torch library
import torch

# create a tensor with requires_gradient=true
x = torch.tensor(2.0, requires_grad = True)

# print the tensor
print("Tensor:", x)

# tensor.detach operation
x_detach = x.detach()
print("Tensor with detach:", x_detach)

輸出

Tensor: tensor(2., requires_grad=True)
Tensor with detach: tensor(2.)

請注意,在以上輸出中,detach後的張量沒有requires_grad = True

樣例2

# import torch library
import torch

# define a tensor with requires_grad=true
x = torch.rand(3, requires_grad = True)
print("x:", x)

# apply above tensor to use detach()
y = 3 + x
z = 3 * x.detach()

print("y:", y)
print("z:", z)

輸出

x: tensor([0.5656, 0.8402, 0.6661], requires_grad=True)
y: tensor([3.5656, 3.8402, 3.6661], grad_fn=<AddBackward0>)
z: tensor([1.6968, 2.5207, 1.9984])

更新於: 06-Dec-2021

已瀏覽超過12K次

開啟您的 職業

完成課程獲得認證

開始吧
廣告
© . All rights reserved.