如何在 PyTorch 中计算张量的均值和标准差?
PyTorch 张量类似于 NumPy 数组。唯一的区别是张量利用 GPU 加速数值计算。张量的**均值**使用**torch.mean()**方法计算。它返回输入张量中所有元素的均值。我们还可以通过提供合适的轴或维度来按行和按列计算均值。
张量的标准差使用**torch.std()**计算。它返回张量中所有元素的标准差。与**均值**一样,我们也可以按行或按列计算**标准差**。
步骤
导入所需的库。在以下所有 Python 示例中,所需的 Python 库为**torch**。请确保您已安装它。
定义一个 PyTorch 张量并打印它。
使用**torch.mean(input, axis)**计算均值。这里,input 是要计算均值的张量,axis(或**dim**)是维度的列表。将计算出的均值赋值给一个新变量。
使用**torch.std(input, axis)**计算标准差。这里,input 是**张量**,**axis**(或**dim**)是维度的列表。将计算出的标准差赋值给一个新变量。
打印上面计算出的均值和标准差。
示例 1
以下 Python 程序演示了如何计算一维张量的均值和标准差。
# Python program to compute mean and standard # deviation of a 1D tensor # import the library import torch # Create a tensor T = torch.Tensor([2.453, 4.432, 0.754, -6.554]) print("T:", T) # Compute the mean and standard deviation mean = torch.mean(T) std = torch.std(T) # Print the computed mean and standard deviation print("Mean:", mean) print("Standard deviation:", std)
输出
T: tensor([ 2.4530, 4.4320, 0.7540, -6.5540]) Mean: tensor(0.2713) Standard deviation: tensor(4.7920)
示例 2
以下 Python 程序演示了如何在两个维度上计算二维张量的均值和标准差,即按行和按列计算。
# import necessary library import torch # create a 3x4 2D tensor T = torch.Tensor([[2,4,7,-6], [7,33,-62,23], [2,-6,-77,54]]) print("T:\n", T) # compute the mean and standard deviation mean = torch.mean(T) std = torch.std(T) print("Mean:", mean) print("Standard deviation:", std) # Compute column-wise mean and std mean = torch.mean(T, axis = 0) std = torch.std(T, axis = 0) print("Column-wise Mean:\n", mean) print("Column-wise Standard deviation:\n", std) # Compute row-wise mean and std mean = torch.mean(T, axis = 1) std = torch.std(T, axis = 1) print("Row-wise Mean:\n", mean) print("Row-wise Standard deviation:\n", std)
输出
T: tensor([[ 2., 4., 7., -6.], [ 7., 33., -62., 23.], [ 2., -6., -77., 54.]]) Mean: tensor(-1.5833) Standard deviation: tensor(36.2703) Column-wise Mean: tensor([ 3.6667, 10.3333, -44.0000, 23.6667]) Column-wise Standard deviation: tensor([ 2.8868, 20.2567, 44.7996, 30.0056]) Row-wise Mean: tensor([ 1.7500, 0.2500, -6.7500]) Row-wise Standard deviation: tensor([ 5.5603, 42.8593, 53.8602])
广告