如何在 PyTorch 中计算张量的均值和标准差?


PyTorch 张量类似于 NumPy 数组。唯一的区别是张量利用 GPU 加速数值计算。张量的**均值**使用**torch.mean()**方法计算。它返回输入张量中所有元素的均值。我们还可以通过提供合适的轴或维度来按行和按列计算均值。

张量的标准差使用**torch.std()**计算。它返回张量中所有元素的标准差。与**均值**一样,我们也可以按行或按列计算**标准差**。

步骤

  • 导入所需的库。在以下所有 Python 示例中,所需的 Python 库为**torch**。请确保您已安装它。

  • 定义一个 PyTorch 张量并打印它。

  • 使用**torch.mean(input, axis)**计算均值。这里,input 是要计算均值的张量,axis(或**dim**)是维度的列表。将计算出的均值赋值给一个新变量。

  • 使用**torch.std(input, axis)**计算标准差。这里,input 是**张量**,**axis**(或**dim**)是维度的列表。将计算出的标准差赋值给一个新变量。

  • 打印上面计算出的均值和标准差。

示例 1

以下 Python 程序演示了如何计算一维张量的均值和标准差。

# Python program to compute mean and standard
# deviation of a 1D tensor
# import the library
import torch

# Create a tensor
T = torch.Tensor([2.453, 4.432, 0.754, -6.554])
print("T:", T)

# Compute the mean and standard deviation
mean = torch.mean(T)
std = torch.std(T)

# Print the computed mean and standard deviation
print("Mean:", mean)
print("Standard deviation:", std)

输出

T: tensor([ 2.4530, 4.4320, 0.7540, -6.5540])
Mean: tensor(0.2713)
Standard deviation: tensor(4.7920)

示例 2

以下 Python 程序演示了如何在两个维度上计算二维张量的均值和标准差,即按行和按列计算。

# import necessary library
import torch

# create a 3x4 2D tensor
T = torch.Tensor([[2,4,7,-6],
[7,33,-62,23],
[2,-6,-77,54]])
print("T:\n", T)

# compute the mean and standard deviation
mean = torch.mean(T)
std = torch.std(T)
print("Mean:", mean)
print("Standard deviation:", std)

# Compute column-wise mean and std
mean = torch.mean(T, axis = 0)
std = torch.std(T, axis = 0)
print("Column-wise Mean:\n", mean)
print("Column-wise Standard deviation:\n", std)

# Compute row-wise mean and std
mean = torch.mean(T, axis = 1)
std = torch.std(T, axis = 1)
print("Row-wise Mean:\n", mean)
print("Row-wise Standard deviation:\n", std)

输出

T:
tensor([[ 2., 4., 7., -6.],
         [ 7., 33., -62., 23.],
         [ 2., -6., -77., 54.]])
Mean: tensor(-1.5833)
Standard deviation: tensor(36.2703)
Column-wise Mean:
tensor([ 3.6667, 10.3333, -44.0000, 23.6667])
Column-wise Standard deviation:
tensor([ 2.8868, 20.2567, 44.7996, 30.0056])
Row-wise Mean:
tensor([ 1.7500, 0.2500, -6.7500])
Row-wise Standard deviation:
tensor([ 5.5603, 42.8593, 53.8602])

更新于: 2021年11月6日

6K+ 阅读量

启动你的 职业生涯

通过完成课程获得认证

开始学习
广告