PyTorch – torchvision.transforms – RandomErasing()
**RandomErasing()** 变换会在输入图像中随机选择一个矩形区域,并擦除其像素。**torchvision.transforms** 模块提供了许多重要的变换,可用于对图像数据执行不同类型的操作。**RandomErasing()** 变换只接受任意大小的张量图像。张量图像是一个torch张量。
由于此变换仅支持张量图像,因此应首先将 PIL 图像转换为 torch 张量。应用 **RandomErasing()** 变换后,我们将 torch 张量图像转换为 PIL 图像。
步骤
我们可以使用以下步骤在输入图像中随机选择一个矩形区域并擦除其像素:
导入所需的库。在以下所有示例中,所需的 Python 库为 **torch、Pillow** 和 **torchvision**。确保您已安装它们。
import torch import torchvision import torchvision.transforms as T from PIL import Image
读取输入图像。输入图像可以是 PIL 图像或 torch 张量。
img = Image.open('sky.jpg')如果输入图像是 PIL 图像,请将其转换为 torch 张量。
imgTensor = T.ToTensor()(img)
定义 **RandomErasing()** 变换。
transform = T.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)
将上述定义的变换应用于输入图像,以随机选择输入图像中的矩形区域并擦除其像素。
imgTensor = transform(imgTensor)
将上述变换后的张量图像转换为 PIL 图像。
img = T.ToPILImage()imgTensor)
显示归一化图像。
img.show()
注意
或者,我们可以定义上述在第 3、4 和 6 步中执行的三个变换的组合。
transform = T.Compose([ T.ToTensor(), T.RandomErasing(p=1, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False), T.ToPILImage() ])
输入图像
此图像在以下所有示例中用作输入文件。

示例 1
以下程序会从输入图像中随机选择一个矩形区域并擦除其像素。在这里,我们将概率设置为 1,因此它一定会选择图像中的一个区域并擦除其像素。
# import required libraries
import torch
import torchvision.transforms as T
from PIL import Image
# read the input image
img = Image.open('sky.jpg')
# define a transform to perform three transformations:
# convert PIL image to tensor
# randomly select a rectangle region in a torch Tensor image
# and erase its pixels
# convert the tensor to PIL image
transform = T.Compose([ T.ToTensor(), T.RandomErasing(p=1, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False), T.ToPILImage() ])
# apply the transform on image
img = transform(img)
# display the output image
img.show()输出
它将产生以下输出:
.jpg)
示例 2
让我们来看另一个例子:
import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt
# read input image
img = Image.open('sky.jpg')
# define a transform to perform transformations
transform = T.Compose([T.ToTensor(), T.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False), T.ToPILImage()])
# use dictionary comprehension to take four output images
imgs = [transform(img) for _ in range(4)]
# display four output images
fig = plt.figure(figsize=(7,4))
rows, cols = 2,2
for j in range(0, len(imgs)):
fig.add_subplot(rows, cols, j+1)
plt.imshow(imgs[j])
plt.xticks([])
plt.yticks([])
plt.show()输出
它将产生以下输出:
.jpg)
请注意,在上图输出图像中,至少有两张图像中有被擦除的区域,因为我们将概率设置为 0.5。
广告
数据结构
网络
关系数据库管理系统 (RDBMS)
操作系统
Java
iOS
HTML
CSS
Android
Python
C语言编程
C++
C#
MongoDB
MySQL
Javascript
PHP