PyTorch – 如何在随机位置裁剪图像?


要在随机位置裁剪图像,我们应用**RandomCrop()**变换。这是**torchvision.transforms**模块提供的众多重要变换之一。

**RandomCrop()**变换接受PIL图像和张量图像。张量图像是一个形状为**[C, H, W]**的torch张量,其中C是通道数,H是图像高度,W是图像宽度。

如果图像既不是PIL图像也不是张量图像,那么我们首先将其转换为张量图像,然后应用**RandomCrop()**。

语法

torchvision.transforms.RandomCrop(size)(img)

其中size是所需的裁剪大小。size是一个类似于**(h,w)**的序列,其中**h**和**w**分别是裁剪图像的高度和宽度。如果**size**是**int**,则裁剪后的图像将是正方形图像。

它返回在给定大小的随机位置裁剪的图像。

步骤

我们可以使用以下步骤在给定大小的随机位置裁剪图像:

  • 导入所需的库。在以下所有示例中,所需的Python库是**torch、Pillow**和**torchvision**。确保您已安装它们。

import torch
import torchvision
import torchvision.transforms as T
from PIL import Image
  • 读取输入图像。输入图像为PIL图像或torch张量。

img = Image.open('meteor.jpg')
  • 定义一个变换,以便在随机位置裁剪图像。矩形裁剪的裁剪大小为(200,250),正方形裁剪的裁剪大小为250。根据您的需要更改裁剪大小。

# transform for rectangular crop
transform = T.RandomCrop((200,250))
# transform for square crop
transform = T.RandomCrop(250)
  • 将上述定义的变换应用于输入图像,以便在随机位置裁剪图像。

img = transform(img)
  • 可视化裁剪后的图像。

img.show()

输入图像

此图像用作以下所有示例中的输入。

示例1

以下Python 3程序显示了如何在随机位置裁剪输入PIL图像。

# import required libraries
import torch
import torchvision.transforms as T
from PIL import Image

# read the input image
img = Image.open('meteor.png')

# define transform to crop the image at
# random location
transform = T.RandomCrop((250,500))
img = transform(img)
img.show()

输出

它将产生以下输出:

示例2

import torch
import torchvision.transforms as T
from PIL import Image

img = Image.open('lena.jpg')
transform = T.RandomCrop((250,500), padding=50)
img = transform(img)
img.show()

输出

它将产生以下输出。请注意,填充是随机的。

示例3

# import required libraries
import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt

# read the input image
img = Image.open('meteor.png')

# define the transform with crop size
transform = T.RandomCrop((100,150))

# crop four images
imgs = [transform(img) for _ in range(4)]

# display these cropped images
fig = plt.figure(figsize=(7,3))
rows, cols = len(imgs),1
for j in range(0, len(imgs)):
   fig.add_subplot(rows, cols, j+1)
   plt.imshow(imgs[j])
   #plt.xticks([])
   #plt.yticks([])
plt.show()

输出

它将产生以下输出:

更新于:2022年1月6日

3K+ 次浏览

启动你的职业生涯

完成课程获得认证

开始学习
广告