如何在 PyTorch 中計算張量的均值和標準差?


PyTorch 張量類似於 NumPy 陣列。唯一的區別是張量利用 GPU 加速數值計算。張量的**均值**使用**torch.mean()**方法計算。它返回輸入張量中所有元素的均值。我們還可以透過提供合適的軸或維度來按行和按列計算均值。

張量的標準差使用**torch.std()**計算。它返回張量中所有元素的標準差。與**均值**一樣,我們也可以按行或按列計算**標準差**。

步驟

  • 匯入所需的庫。在以下所有 Python 示例中,所需的 Python 庫為**torch**。請確保您已安裝它。

  • 定義一個 PyTorch 張量並列印它。

  • 使用**torch.mean(input, axis)**計算均值。這裡,input 是要計算均值的張量,axis(或**dim**)是維度的列表。將計算出的均值賦值給一個新變數。

  • 使用**torch.std(input, axis)**計算標準差。這裡,input 是**張量**,**axis**(或**dim**)是維度的列表。將計算出的標準差賦值給一個新變數。

  • 列印上面計算出的均值和標準差。

示例 1

以下 Python 程式演示瞭如何計算一維張量的均值和標準差。

# Python program to compute mean and standard
# deviation of a 1D tensor
# import the library
import torch

# Create a tensor
T = torch.Tensor([2.453, 4.432, 0.754, -6.554])
print("T:", T)

# Compute the mean and standard deviation
mean = torch.mean(T)
std = torch.std(T)

# Print the computed mean and standard deviation
print("Mean:", mean)
print("Standard deviation:", std)

輸出

T: tensor([ 2.4530, 4.4320, 0.7540, -6.5540])
Mean: tensor(0.2713)
Standard deviation: tensor(4.7920)

示例 2

以下 Python 程式演示瞭如何在兩個維度上計算二維張量的均值和標準差,即按行和按列計算。

# import necessary library
import torch

# create a 3x4 2D tensor
T = torch.Tensor([[2,4,7,-6],
[7,33,-62,23],
[2,-6,-77,54]])
print("T:\n", T)

# compute the mean and standard deviation
mean = torch.mean(T)
std = torch.std(T)
print("Mean:", mean)
print("Standard deviation:", std)

# Compute column-wise mean and std
mean = torch.mean(T, axis = 0)
std = torch.std(T, axis = 0)
print("Column-wise Mean:\n", mean)
print("Column-wise Standard deviation:\n", std)

# Compute row-wise mean and std
mean = torch.mean(T, axis = 1)
std = torch.std(T, axis = 1)
print("Row-wise Mean:\n", mean)
print("Row-wise Standard deviation:\n", std)

輸出

T:
tensor([[ 2., 4., 7., -6.],
         [ 7., 33., -62., 23.],
         [ 2., -6., -77., 54.]])
Mean: tensor(-1.5833)
Standard deviation: tensor(36.2703)
Column-wise Mean:
tensor([ 3.6667, 10.3333, -44.0000, 23.6667])
Column-wise Standard deviation:
tensor([ 2.8868, 20.2567, 44.7996, 30.0056])
Row-wise Mean:
tensor([ 1.7500, 0.2500, -6.7500])
Row-wise Standard deviation:
tensor([ 5.5603, 42.8593, 53.8602])

更新於: 2021年11月6日

6K+ 閱讀量

啟動你的 職業生涯

透過完成課程獲得認證

開始學習
廣告

© . All rights reserved.