Python – PyTorch clamp() 方法


**torch.clamp()** 用於將輸入張量中的所有元素限制在 **[min, max]** 範圍內。它接受三個引數:**輸入**張量、**min** 值和 **max** 值。小於 min 的值將被替換為 **min**,大於 max 的值將被替換為 **max**。

如果未指定 **min**,則沒有下界。如果未指定 **max**,則沒有上界。例如,如果我們設定 **min=-0.5** 和 **max=0.4**,則小於 -0.5 的值將被替換為 -0.5,大於 0.4 的值將被替換為 0.4。介於這兩個值之間的值將保持不變。它只支援實數值輸入。

語法

torch.clamp(input, min=None, max=None)

引數

  • **input** - 輸入張量。

  • **min** - 下界;可以是數字或張量。

  • **max** - 上界;可以是數字或張量。

它返回一個新的張量,其中所有元素都被限制在 **[min, max]** 範圍內。

步驟

  • 匯入所需的庫。在以下所有示例中,所需的 Python 庫是 **torch**。確保您已安裝它。

import torch
  • 建立一個輸入張量並列印它。

a = torch.tensor([0.73, 0.35, -0.39, -1.53])
print("input tensor:
", a)
  • 限制輸入張量的元素。這裡我們使用 **min=-0.5, max=0.5**。

t1 = torch.clamp(a, min=-0.5, max=0.5)
  • 列印 clamp 之後獲得的張量。

print(t1)

示例 1

在下面的 Python 程式中,我們限制一維輸入張量的元素。請注意當 **min** 或 **max** 為 **None** 時,**clamp()** 方法是如何工作的。

# Import the required library
import torch

# define a 1D tensor
a = torch.tensor([ 0.73, 0.35, -0.39, -1.53])
print("input tensor:
", a) print("clamp the tensor:") print("into range [-0.5, 0.5]:") t1 = torch.clamp(a, min=-0.5, max=0.5) print(t1) print("if min is None:") t2 = torch.clamp(a, max=0.5) print(t2) print("if max is None:") t3 = torch.clamp(a, min=0.5) print(t3) print("if min is greater than max:") t4 = torch.clamp(a, min=0.6, max=.5) print(t4)

輸出

input tensor:
   tensor([ 0.7300, 0.3500, -0.3900, -1.5300])
clamp the tensor:
into range [-0.5, 0.5]:
   tensor([ 0.5000, 0.3500, -0.3900, -0.5000])
if min is None:
   tensor([ 0.5000, 0.3500, -0.3900, -1.5300])
if max is None:
   tensor([0.7300, 0.5000, 0.5000, 0.5000])
if min is greater than max:
   tensor([0.5000, 0.5000, 0.5000, 0.5000])

示例 2

在下面的 Python 程式中,我們限制二維輸入張量的元素。請注意當 **min** 或 **max** 為 **None** 時,**clamp()** 方法是如何工作的。

# Import the required library
import torch

# define a 2D tensor of size [3, 4]
a = torch.randn(3,4)
print("input tensor:
", a) print("clamp the tensor:") print("into range [-0.6, 0.4]:") t1 = torch.clamp(a, min=-0.6, max=0.4) print(t1) print("if min is None (max=0.4):") t2 = torch.clamp(a, max=0.4) print(t2) print("if max is None (min=-0.6):") t3 = torch.clamp(a, min=-0.6) print(t3) print("if min is greater than max (min=0.6, max=0.4):") t4 = torch.clamp(a, min=0.6, max=0.4) print(t4)

輸出

input tensor:
   tensor([[ 1.2133, 0.2199, -0.0864, -0.1143],
      [ 0.4205, 1.0258, 0.4022, -1.3172],
      [ 1.5405, 0.8545, 0.7009, 0.5874]])
clamp the tensor:
into range [-0.6, 0.4]:
   tensor([[ 0.4000, 0.2199, -0.0864, -0.1143],
      [ 0.4000, 0.4000, 0.4000, -0.6000],
      [ 0.4000, 0.4000, 0.4000, 0.4000]])
if min is None (max=0.4):
   tensor([[ 0.4000, 0.2199, -0.0864, -0.1143],
      [ 0.4000, 0.4000, 0.4000, -1.3172],
      [ 0.4000, 0.4000, 0.4000, 0.4000]])
if max is None (min=-0.6):
   tensor([[ 1.2133, 0.2199, -0.0864, -0.1143],
      [ 0.4205, 1.0258, 0.4022, -0.6000],
      [ 1.5405, 0.8545, 0.7009, 0.5874]])
if min is greater than max (min=0.6, max=0.4):
   tensor([[0.4000, 0.4000, 0.4000, 0.4000],
      [0.4000, 0.4000, 0.4000, 0.4000],
      [0.4000, 0.4000, 0.4000, 0.4000]])

更新於:2022年1月20日

5000+ 瀏覽量

啟動您的 職業生涯

完成課程獲得認證

開始學習
廣告