PyTorch – torchvision.transforms – RandomErasing()
**RandomErasing()** 变换会在输入图像中随机选择一个矩形区域,并擦除其像素。**torchvision.transforms** 模块提供了许多重要的变换,可用于对图像数据执行不同类型的操作。**RandomErasing()** 变换只接受任意大小的张量图像。张量图像是一个torch张量。
由于此变换仅支持张量图像,因此应首先将 PIL 图像转换为 torch 张量。应用 **RandomErasing()** 变换后,我们将 torch 张量图像转换为 PIL 图像。
步骤
我们可以使用以下步骤在输入图像中随机选择一个矩形区域并擦除其像素:
导入所需的库。在以下所有示例中,所需的 Python 库为 **torch、Pillow** 和 **torchvision**。确保您已安装它们。
import torch import torchvision import torchvision.transforms as T from PIL import Image
读取输入图像。输入图像可以是 PIL 图像或 torch 张量。
img = Image.open('sky.jpg')
如果输入图像是 PIL 图像,请将其转换为 torch 张量。
imgTensor = T.ToTensor()(img)
定义 **RandomErasing()** 变换。
transform = T.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)
将上述定义的变换应用于输入图像,以随机选择输入图像中的矩形区域并擦除其像素。
imgTensor = transform(imgTensor)
将上述变换后的张量图像转换为 PIL 图像。
img = T.ToPILImage()imgTensor)
显示归一化图像。
img.show()
注意
或者,我们可以定义上述在第 3、4 和 6 步中执行的三个变换的组合。
transform = T.Compose([ T.ToTensor(), T.RandomErasing(p=1, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False), T.ToPILImage() ])
输入图像
此图像在以下所有示例中用作输入文件。
示例 1
以下程序会从输入图像中随机选择一个矩形区域并擦除其像素。在这里,我们将概率设置为 1,因此它一定会选择图像中的一个区域并擦除其像素。
# import required libraries import torch import torchvision.transforms as T from PIL import Image # read the input image img = Image.open('sky.jpg') # define a transform to perform three transformations: # convert PIL image to tensor # randomly select a rectangle region in a torch Tensor image # and erase its pixels # convert the tensor to PIL image transform = T.Compose([ T.ToTensor(), T.RandomErasing(p=1, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False), T.ToPILImage() ]) # apply the transform on image img = transform(img) # display the output image img.show()
输出
它将产生以下输出:
示例 2
让我们来看另一个例子:
import torch import torchvision.transforms as T from PIL import Image import matplotlib.pyplot as plt # read input image img = Image.open('sky.jpg') # define a transform to perform transformations transform = T.Compose([T.ToTensor(), T.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False), T.ToPILImage()]) # use dictionary comprehension to take four output images imgs = [transform(img) for _ in range(4)] # display four output images fig = plt.figure(figsize=(7,4)) rows, cols = 2,2 for j in range(0, len(imgs)): fig.add_subplot(rows, cols, j+1) plt.imshow(imgs[j]) plt.xticks([]) plt.yticks([]) plt.show()
输出
它将产生以下输出:
请注意,在上图输出图像中,至少有两张图像中有被擦除的区域,因为我们将概率设置为 0.5。
广告