PyTorch – 五裁剪变换


为了将给定图像裁剪成四个角和中心裁剪,我们应用**FiveCrop()**变换。这是torchvision.transforms模块提供的众多变换之一。此模块包含许多重要的变换,可用于对图像数据执行不同类型的操作。

**FiveCrop()**变换接受PIL图像和张量图像。张量图像是一个形状为**[C, H, W]**的torch张量,其中C是通道数,H是图像高度,W是图像宽度。如果图像既不是PIL图像也不是张量图像,则我们首先将其转换为张量图像,然后应用**FiveCrop**变换。

语法

torchvision.transforms.FiveCrop(size)

其中size是所需的裁剪大小。size是一个类似于**(h, w)**的序列,其中h和w分别是每个裁剪图像的高度和宽度。如果**size**是**int**,则裁剪的图像是正方形的。

它返回一个包含五个裁剪图像的元组,四个角图像和一个中心图像。

步骤

我们可以使用以下步骤将图像裁剪成四个图像和一个给定大小的中心裁剪:

  • 导入所需的库。在以下所有示例中,所需的Python库是**torch、Pillow**和**torchvision**。确保您已安装它们。

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
  • 读取输入图像。输入图像为PIL图像或torch张量。

img = Image.open('recording.jpg')
  • 定义一个变换,将图像裁剪成四个角和中心裁剪。对于矩形裁剪,裁剪大小设置为(150, 300),对于正方形裁剪,裁剪大小设置为250。根据您的需要更改裁剪大小。

# transform for rectangular crop
transform = transforms.FiveCrop((200,250))

# transform for square crop
transform = transforms.FiveCrop(250)
  • 将上面定义的变换应用于输入图像,以将其裁剪成四个角和中心裁剪。

img = transform(img)
  • 显示所有五个裁剪图像。

输入图像

我们将在以下两个示例中使用此图像。

示例1

在下面的Python3程序中,我们裁剪四个角和一个中心裁剪。五个裁剪图像都是矩形的。

# import required libraries
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

# Read the image
img = Image.open('recording.jpg')

# define a transform to crop the image into four
# corners and the central crop
transform = transforms.FiveCrop((150, 300))

# apply the above transform on the image
imgs = transform(img)

# This transform returns a tuple of 5 images
print(type(imgs))
print("Total cropped images:",len(imgs))

输出

<class 'tuple'>
Total cropped images: 5

示例2

在下面的Python3程序中,我们裁剪四个角和一个中心裁剪。五个裁剪图像都是正方形的。

# import required libraries
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

# Read the image
img = Image.open('recording.jpg')

# define a transform to crop the image into four
# corners and the central crop
transform = transforms.FiveCrop(200)

# apply the above transform on the image
imgs = transform(img)

# Define a figure of size (8, 8)
fig=plt.figure(figsize=(8, 8))

# Define row and cols in the figure
rows, cols = 1, 5

# Display all 5 cropped images
for j in range(0, cols*rows):
   fig.add_subplot(rows, cols, j+1)
   plt.imshow(imgs[j])
   plt.xticks([])
   plt.yticks([])
plt.show()

输出

它将产生以下输出:

更新于:2022年1月6日

767 次浏览

启动您的职业生涯

完成课程获得认证

开始学习
广告