如何在 PyTorch 中逐元素應用修正線性單元函式?
要對輸入張量逐元素應用修正線性單元 (ReLU) 函式,我們使用 **torch.nn.ReLU()**。它將輸入張量中所有負元素替換為 0(零),所有非負元素保持不變。它僅支援實值輸入張量。**ReLU** 用作神經網路中的啟用函式。
語法
relu = torch.nn.ReLU() output = relu(input)
步驟
您可以使用以下步驟逐元素應用修正線性單元 (ReLU) 函式:
匯入所需的庫。在以下所有示例中,所需的 Python 庫為 **torch**。確保您已安裝它。
import torch import torch.nn as nn
定義 **輸入** 張量並列印它。
input = torch.randn(2,3)
print("Input Tensor:
",input)使用 **torch.nn.ReLU()** 定義 ReLU 函式 **relu**。
relu = torch.nn.ReLU()
將上面定義的 ReLU 函式 **relu** 應用於輸入張量。並可以選擇將輸出分配給一個新變數
output = relu(input)
列印包含 ReLU 函式值的張量。
print("ReLU Tensor:
",output)讓我們看幾個例子,以便更好地理解它的工作原理。
示例 1
# Import the required library
import torch
import torch.nn as nn
relu = torch.nn.ReLU()
input = torch.tensor([[-1., 8., 1., 13., 9.],
[ 0., 1., 0., 5., -5.],
[ 3., -5., 8., -1., 5.],
[ 0., 3., -1., 13., 12.]])
print("Input Tensor:
",input)
print("Size of Input Tensor:
",input.size())
# Compute the rectified linear unit (ReLU) function element-wise
output = relu(input)
print("ReLU Tensor:
",output)
print("Size of ReLU Tensor:
",output.size())輸出
Input Tensor: tensor([[-1., 8., 1., 13., 9.], [ 0., 1., 0., 5., -5.], [ 3., -5., 8., -1., 5.], [ 0., 3., -1., 13., 12.]]) Size of Input Tensor: torch.Size([4, 5]) ReLU Tensor: tensor([[ 0., 8., 1., 13., 9.], [ 0., 1., 0., 5., 0.], [ 3., 0., 8., 0., 5.], [ 0., 3., 0., 13., 12.]]) Size of ReLU Tensor: torch.Size([4, 5])
在上面的示例中,請注意輸出張量中輸入張量的負元素被替換為零。
示例 2
# Import the required library
import torch
import torch.nn as nn
relu = torch.nn.ReLU(inplace=True)
input = torch.randn(4,5)
print("Input Tensor:
",input)
print("Size of Input Tensor:
",input.size())
# Compute the rectified linear unit (ReLU) function element-wise
output = relu(input)
print("ReLU Tensor:
",output)
print("Size of ReLU Tensor:
",output.size())輸出
Input Tensor: tensor([[ 0.4217, 0.4151, 1.3292, -1.3835, -0.0086], [-0.7693, -1.7736, -0.3401, -0.7179, -0.0196], [ 1.0918, -0.9426, 2.1496, -0.4809, -1.2254], [-0.3198, -0.2231, 1.2043, 1.1222, 0.7905]]) Size of Input Tensor: torch.Size([4, 5]) ReLU Tensor: tensor([[0.4217, 0.4151, 1.3292, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [1.0918, 0.0000, 2.1496, 0.0000, 0.0000], [0.0000, 0.0000, 1.2043, 1.1222, 0.7905]]) Size of ReLU Tensor: torch.Size([4, 5])
廣告
資料結構
網路
關係資料庫管理系統
作業系統
Java
iOS
HTML
CSS
Android
Python
C 程式設計
C++
C#
MongoDB
MySQL
Javascript
PHP