PyTorch – torchvision.transforms – RandomResizedCrop()
RandomResizedCrop() 变换会裁剪原始输入图像的随机区域。此裁剪大小是随机选择的,最后裁剪后的图像将调整为给定大小。RandomResizedCrop() 变换是torchvision.transforms 模块提供的众多变换之一。此模块包含许多重要的变换,可用于对图像数据执行不同类型的操作。
RandomResizedCrop() 接受 PIL 和张量图像。张量图像是一个 PyTorch 张量,形状为[..., H, W],其中 ... 表示任意数量的维度,H 是图像高度,W 是图像宽度。如果图像既不是 PIL 图像也不是张量图像,则首先将其转换为张量图像,然后应用变换。
语法
torchvision.transforms.RandomResizedCrop(size)(img)
其中 size 是所需的裁剪大小。size 是一个类似于(h, w)的序列,其中 h 和 w 分别是裁剪图像的高度和宽度。如果size 是一个int,则裁剪后的图像为正方形图像。
它返回使用给定大小调整大小的裁剪图像。
步骤
我们可以使用以下步骤裁剪输入图像的随机部分并将其调整为给定大小:
导入所需的库。在以下所有示例中,所需的 Python 库为torch、Pillow 和torchvision。确保您已安装它们。
import torch import torchvision import torchvision.transforms as T from PIL import Image import matplotlib.pyplot as plt
读取输入图像。输入图像为 PIL 图像或形状为 [..., H, W] 的 torch 张量。
img = Image.open('baseball.png')
定义一个变换,以裁剪输入图像上的随机部分,然后调整为给定大小。此处给定大小为 (150,250) 用于矩形裁剪,250 用于正方形裁剪。根据您的需要更改裁剪大小。
# transform for rectangular crop transform = T.RandomResizedCrop((150,250)) # transform for square crop transform = T.RandomResizedCrop(250)
将上述定义的变换应用于输入图像,以裁剪输入图像上的随机部分,然后将其调整为给定大小。
cropped_img = transform(img)
显示裁剪后的图像,然后显示调整大小的图像
cropped_img.show()
输入图像
此图像用作以下所有示例中的输入。
示例 1
在此程序中,裁剪输入图像的随机部分,然后将其大小调整为 (150, 250)。
# import required libraries import torch import torchvision.transforms as T from PIL import Image import matplotlib.pyplot as plt # read the input image img = Image.open('baseball.png') # define a transform to crop a random portion of an image # and resize it to given size transform = T.RandomResizedCrop(size=(350,600)) # apply above defined transform to the input image img = transform(img) # display the cropped image img.show()
输出
它将产生以下输出:
示例 2
import torch import torchvision.transforms as T from PIL import Image import matplotlib.pyplot as plt img = Image.open('baseball.png') transform = T.RandomResizedCrop(size = (200,150), scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333)) imgs = [transform(img) for _ in range(4)] fig = plt.figure(figsize=(7,3)) rows, cols = 2,2 for j in range(0, len(imgs)): fig.add_subplot(rows, cols, j+1) plt.imshow(imgs[j]) plt.xticks([]) plt.yticks([]) plt.show()
输出
它将产生以下输出: