如何在 PyTorch 中裁剪图像中心?


要裁剪图像中心,我们使用 **CenterCrop()**。它是 torchvision.transforms 模块提供的众多转换之一。此模块包含许多可用于对图像数据进行操作的重要转换。

**CenterCrop()** 转换接受 PIL 和张量图像。张量图像是一个形状为 **[C, H, W]** 的 PyTorch 张量,其中 C 是通道数,H 是图像高度,W 是图像宽度。

此转换还接受一批张量图像。一批张量图像是一个形状为 **[B, C, H, W]** 的张量。**B** 是批次中的图像数量。如果图像既不是 PIL 图像也不是张量图像,则我们首先将其转换为张量图像,然后应用 **CenterCrop()** 转换。

语法

torchvision.transforms.CenterCrop(size)

参数

  • **size** – 期望的裁剪尺寸。**size** 是一个类似于 **(h, w)** 的序列,其中 **h** 和 **w** 分别是裁剪图像的高度和宽度。如果 **size** 是一个 **int**,则裁剪后的图像将为正方形图像。

它返回给定大小的裁剪图像。

步骤

我们可以使用以下步骤在中心裁剪给定大小的图像。

  • 导入所需的库。在以下所有示例中,所需的 Python 库为 **torch、Pillow** 和 **torchvision**。确保您已安装它们。

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
  • 读取输入图像。输入图像为 PIL 图像或形状为 [..., H, W] 的 torch 张量。

img = Image.open('lena.jpg')
  • 定义一个转换以在中心裁剪图像。对于矩形裁剪,裁剪大小为 (200,250),对于正方形裁剪,裁剪大小为 250。根据您的需要更改裁剪大小。

# transform for rectangular crop
transform = transforms.CenterCrop((200,250))

# transform for square crop
transform = transforms.CenterCrop(250)
  • 将上面定义的转换应用于输入图像以在中心裁剪图像。

img = transform(img)
  • 可视化裁剪后的图像

img.show()

输入图像

以下图像用作两个示例中的输入图像。

示例 1

以下 Python 程序演示了如何在中心裁剪图像。裁剪后的图像为正方形图像。在此程序中,我们将输入图像读取为 PIL 图像。

# Python program to crop an image at center
# import required libraries
import torch
import torchvision.transforms as transforms
from PIL import Image

# Read the image
img = Image.open('waves.png')

# define a transform to crop the image at center
transform = transforms.CenterCrop(250)

# crop the image using above defined transform
img = transform(img)

# visualize the image
img.show()

输出

它将产生以下输出 -

示例 2

此 Python 程序在中心裁剪图像,并给出指定的高度和宽度。在此程序中,我们将输入图像读取为 PIL 图像。

# Python program to crop an image at center
# import torch library
import torch
import torchvision.transforms as transforms
from PIL import Image

# Read the image
img = Image.open('waves.png')

# define a transform to crop the image at center
transform = transforms.CenterCrop((150,500))

# crop the image using above defined transform
img = transform(img)

# visualize the image
img.show()

输出

生成的输出图像将为 150px 高和 500px 宽。

示例 3

在此程序中,我们将输入图像读取为 OpenCV 图像。我们定义一个转换,它是三个转换的组合。我们首先将图像转换为张量图像,然后应用 **CenterCrop()**,最后将裁剪后的张量图像转换为 PIL 图像。

# import the required libraries
import torch
import torchvision.transforms as transforms
import cv2

# read the inputimage
img = cv2.imread('waves.png')

# convert image from BGR to RGB
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Define a transform. It is a composition
# of three transforms
transform = transforms.Compose([
   transforms.ToTensor(),        # Converts to PyTorch Tensor
   transforms.CenterCrop(250),   # crops at center
   transforms.ToPILImage()       # converts the tensor to PIL image
])
# apply the above transform to crop the image
img = transform(img)

# display the cropped image
img.show()

输出

它将产生以下输出 -

更新于: 2022年1月6日

4K+ 次查看

启动您的 职业生涯

通过完成课程获得认证

开始
广告