PyTorch – 如何使用均值和标准差对图像进行归一化?
**Normalize()** 变换使用均值和标准差对图像进行归一化。**torchvision.transforms** 模块提供了许多重要的变换,可用于对图像数据执行不同类型的操作。
**Normalize()** 仅接受任何大小的张量图像。张量图像是一个 torch 张量。张量图像可能具有 n 个通道。**Normalize()** 变换对每个通道的张量图像进行归一化。
由于此变换仅支持张量图像,因此应先将 PIL 图像转换为 torch 张量。应用 **Normalize()** 变换后,我们将归一化的 torch 张量转换为 PIL 图像。
步骤
我们可以使用以下步骤来使用均值和标准差对图像进行归一化:
导入所需的库。在以下所有示例中,所需的 Python 库为 **torch、Pillow** 和 **torchvision**。请确保您已安装它们。
import torch import torchvision import torchvision.transforms as T from PIL import Image
读取输入图像。输入图像可以是 PIL 图像或 torch 张量。如果输入图像为 PIL 图像,请将其转换为 torch 张量。
img = Image.open('sunset.jpg') # convert image to torch tensor imgTensor = T.ToTensor()(img)
定义一个变换,使用均值和标准差对图像进行归一化。这里,我们使用 ImageNet 数据集的均值和标准差。
transform = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
将上面定义的变换应用于输入图像以对图像进行归一化。
normalized_imgTensor = transform(imgTensor)
将归一化的张量图像转换为 PIL 图像。
normalized_img = T.ToPILImage()(normalized_imgTensor)
显示归一化的图像。
normalized _img.show()
输入图像
此图像用作以下所有示例中的输入文件。
示例 1
以下 Python 程序将输入图像归一化到均值和标准差。我们使用 ImageNet 数据集的均值和标准差。
# import required libraries import torch import torchvision.transforms as T from PIL import Image # Read the input image img = Image.open('sunset.jpg') # convert image to torch tensor imgTensor = T.ToTensor()(img) # define a transform to normalize the tensor transform = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # normalize the converted tensor using above defined transform normalized_imgTensor = transform(imgTensor) # convert the normalized tensor to PIL image normalized_img = T.ToPILImage()(normalized_imgTensor) # display the normalized PIL image normalized_img.show()
输出
它将产生以下输出:
示例 2
在此示例中,我们定义了一个 **Compose 变换** 来执行三个变换。
将 PIL 图像转换为张量图像。
归一化张量图像。
将归一化的图像张量转换为 PIL 图像。
# import required libraries import torch import torchvision.transforms as T from PIL import Image # read the input image img = Image.open('sunset.jpg') # define a transform to: # convert the PIL image to tensor # normalize the tensor # convert the tensor to PIL image transform = T.Compose([ T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), T.ToPILImage()]) # apply the above tensor on input image img = transform(img) img.show()
输出
它将产生以下输出:
广告