如何在 PyTorch 中找到图像通道的平均值?


RGB 图像有三个通道:红色、绿色和蓝色。我们需要计算这些图像通道上图像像素值的平均值。为此,我们使用 torch.mean() 方法。但是此方法的输入参数是 PyTorch 张量。因此,我们首先将图像转换为 PyTorch 张量,然后应用此方法。它返回张量中所有元素的平均值。要查找图像通道的平均值,我们将参数 dim 设置为 [1,2]

步骤

  • 导入所需的库。在以下所有 Python 示例中,所需的 Python 库是 torch、torchvision、PillowOpenCV。确保您已安装它们。

  • 使用 image.open() 读取输入图像并将其分配给变量 “img”

  • 定义一个转换,将 PIL 图像转换为 PyTorch 张量。

  • 使用上面定义的转换将图像 “img” 转换为 PyTorch 张量,并将此张量分配给 “imgTensor”

  • 计算 torch.mean(imgTensor, dim = [1,2])。它返回一个包含三个值的张量。这三个值是三个通道 RGB 的平均值。您可以将这三个平均值分别分配给三个新变量 “R_mean”、“G_mean”“B_mean”

  • 打印图像像素的三个平均值 “R_mean”、“G_mean”“B_mean”

输入图像

我们将在两个示例中都使用以下图像作为输入。

Explore our latest online courses and learn new skills at your own pace. Enroll and become a certified expert to boost your career.

示例 1

# Python program to find mean across the image channels # import necessary libraries import torch from PIL import Image import torchvision.transforms as transforms # Read the input image img = Image.open('opera.jpg') # Define transform to convert the image to PyTorch Tensor transform = transforms.ToTensor() # Convert image to PyTorch Tensor (Image Tensor) imgTensor = transform(img) print("Shape of Image Tensor:\n", imgTensor.shape) # Compute mean of the Image Tensor across image channels RGB R_mean, G_mean ,B_mean = torch.mean(imgTensor, dim = [1,2]) # print mean across image channel RGB print("Mean across Read channel:", R_mean) print("Mean across Green channel:", G_mean) print("Mean across Blue channel:", B_mean)

输出

Shape of Image Tensor:
   torch.Size([3, 447, 640])
Mean across Read channel: tensor(0.1487)
Mean across Green channel: tensor(0.1607)
Mean across Blue channel: tensor(0.2521)

示例 2

我们还可以使用 OpenCV 读取图像。使用 OpenCV 读取的图像类型为 numpy.ndarray。在这里,在这个示例中,我们使用了一种不同的方法来计算平均值。我们使用 imgTensor.mean(),这是张量上的基本运算。请查看以下示例。

# Python program to find mean across the image channels # import necessary libraries import torch import cv2 import torchvision.transforms as transforms # Read the input image either using cv2 or PIL img = cv2.imread('opera.jpg') img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Define transform to convert the image to PyTorch Tensor transform = transforms.ToTensor() # Convert image to PyTorch Tensor (Image Tensor) imgTensor = transform(img) print("Shape of Image Tensor:\n", imgTensor.shape) # compute mean of the Image Tensor across image channels RGB # The other way to compute the mean R_mean, G_mean ,B_mean = imgTensor.mean(dim = [1,2]) # print mean across image channel RGB print("Mean across Read channel:", R_mean) print("Mean across Green channel:", G_mean) print("Mean across Blue channel:", B_mean)

输出

Shape of Image Tensor:
   torch.Size([3, 447, 640])
Mean across Read channel: tensor(0.1487)
Mean across Green channel: tensor(0.1607)
Mean across Blue channel: tensor(0.2521)

更新于: 2021 年 11 月 6 日

2K+ 浏览量

开启您的 职业生涯

通过完成课程获得认证

开始学习
广告