Python – PyTorch clamp() 方法
**torch.clamp()** 用於將輸入張量中的所有元素限制在 **[min, max]** 範圍內。它接受三個引數:**輸入**張量、**min** 值和 **max** 值。小於 min 的值將被替換為 **min**,大於 max 的值將被替換為 **max**。
如果未指定 **min**,則沒有下界。如果未指定 **max**,則沒有上界。例如,如果我們設定 **min=-0.5** 和 **max=0.4**,則小於 -0.5 的值將被替換為 -0.5,大於 0.4 的值將被替換為 0.4。介於這兩個值之間的值將保持不變。它只支援實數值輸入。
語法
torch.clamp(input, min=None, max=None)
引數
**input** - 輸入張量。
**min** - 下界;可以是數字或張量。
**max** - 上界;可以是數字或張量。
它返回一個新的張量,其中所有元素都被限制在 **[min, max]** 範圍內。
步驟
匯入所需的庫。在以下所有示例中,所需的 Python 庫是 **torch**。確保您已安裝它。
import torch
建立一個輸入張量並列印它。
a = torch.tensor([0.73, 0.35, -0.39, -1.53])
print("input tensor:
", a)限制輸入張量的元素。這裡我們使用 **min=-0.5, max=0.5**。
t1 = torch.clamp(a, min=-0.5, max=0.5)
列印 clamp 之後獲得的張量。
print(t1)
示例 1
在下面的 Python 程式中,我們限制一維輸入張量的元素。請注意當 **min** 或 **max** 為 **None** 時,**clamp()** 方法是如何工作的。
# Import the required library
import torch
# define a 1D tensor
a = torch.tensor([ 0.73, 0.35, -0.39, -1.53])
print("input tensor:
", a)
print("clamp the tensor:")
print("into range [-0.5, 0.5]:")
t1 = torch.clamp(a, min=-0.5, max=0.5)
print(t1)
print("if min is None:")
t2 = torch.clamp(a, max=0.5)
print(t2)
print("if max is None:")
t3 = torch.clamp(a, min=0.5)
print(t3)
print("if min is greater than max:")
t4 = torch.clamp(a, min=0.6, max=.5)
print(t4)輸出
input tensor: tensor([ 0.7300, 0.3500, -0.3900, -1.5300]) clamp the tensor: into range [-0.5, 0.5]: tensor([ 0.5000, 0.3500, -0.3900, -0.5000]) if min is None: tensor([ 0.5000, 0.3500, -0.3900, -1.5300]) if max is None: tensor([0.7300, 0.5000, 0.5000, 0.5000]) if min is greater than max: tensor([0.5000, 0.5000, 0.5000, 0.5000])
示例 2
在下面的 Python 程式中,我們限制二維輸入張量的元素。請注意當 **min** 或 **max** 為 **None** 時,**clamp()** 方法是如何工作的。
# Import the required library
import torch
# define a 2D tensor of size [3, 4]
a = torch.randn(3,4)
print("input tensor:
", a)
print("clamp the tensor:")
print("into range [-0.6, 0.4]:")
t1 = torch.clamp(a, min=-0.6, max=0.4)
print(t1)
print("if min is None (max=0.4):")
t2 = torch.clamp(a, max=0.4)
print(t2)
print("if max is None (min=-0.6):")
t3 = torch.clamp(a, min=-0.6)
print(t3)
print("if min is greater than max (min=0.6, max=0.4):")
t4 = torch.clamp(a, min=0.6, max=0.4)
print(t4)輸出
input tensor: tensor([[ 1.2133, 0.2199, -0.0864, -0.1143], [ 0.4205, 1.0258, 0.4022, -1.3172], [ 1.5405, 0.8545, 0.7009, 0.5874]]) clamp the tensor: into range [-0.6, 0.4]: tensor([[ 0.4000, 0.2199, -0.0864, -0.1143], [ 0.4000, 0.4000, 0.4000, -0.6000], [ 0.4000, 0.4000, 0.4000, 0.4000]]) if min is None (max=0.4): tensor([[ 0.4000, 0.2199, -0.0864, -0.1143], [ 0.4000, 0.4000, 0.4000, -1.3172], [ 0.4000, 0.4000, 0.4000, 0.4000]]) if max is None (min=-0.6): tensor([[ 1.2133, 0.2199, -0.0864, -0.1143], [ 0.4205, 1.0258, 0.4022, -0.6000], [ 1.5405, 0.8545, 0.7009, 0.5874]]) if min is greater than max (min=0.6, max=0.4): tensor([[0.4000, 0.4000, 0.4000, 0.4000], [0.4000, 0.4000, 0.4000, 0.4000], [0.4000, 0.4000, 0.4000, 0.4000]])
廣告
資料結構
網路
關係資料庫管理系統 (RDBMS)
作業系統
Java
iOS
HTML
CSS
Android
Python
C語言程式設計
C++
C#
MongoDB
MySQL
Javascript
PHP