如何在 PyTorch 中创建图像网格?


torchvision.utils 包提供了 make_grid() 函数来创建图像网格。图像应为 torch 张量。它接受形状为 (B ☓ C ☓ H ☓ W) 的 4D 小批量张量或大小相同的张量图像列表。

  • 这里,B 是批大小,C 是图像中通道的数量,HW 分别是高度和宽度。

  • 所有图像的 H ☓ W 应该相同。

此函数的输出是一个包含图像网格的 torch 张量。我们可以使用 nrow 参数指定一行中的图像数量。我们还有许多其他参数来控制网格输出。要可视化图像网格,我们首先将整个网格转换为 PIL 图像

语法

torchvision.utils.make_grid(tensor)

参数

  • tensor - 张量或张量列表。形状为 (B x C x H x W) 的 4D 小批量张量或大小相同的图像列表。

输出

它返回一个包含图像网格的 torch 张量。

步骤

  • 导入所需的库。在以下所有示例中,所需的 Python 库为 torchtorchvision。确保您已安装它们。

import torch
import torchvision
from torchvision.io import read_image
from torchvision.utils import make_grid
  • 使用 image_read() 函数读取多个 JPEGPNG 图像。使用图像类型(.jpg.png)指定完整图像路径。此函数的输出是一个大小为 [image_channels, image_height, image_width] 的 torch 张量。

img1 = read_image('elephant.jpg')
img2 = read_image('cat.jpg')
img3 = read_image('dog.jpg')
  • 使用 make_grid() 函数创建读取为 torch 张量的输入图像网格。指定 nrow 以在网格中每行显示的图像数量。

grid = make_grid([img1, img2, img3], nrow=3)
  • 将网格张量转换为 PIL 图像并显示它。

img = torchvision.transforms.ToPILImage()(grid)
img.show()

示例 1

在此 Python 程序中,我们读取三个输入图像并创建这些图像的网格。

import torch
import torchvision
from torchvision.io import read_image
from torchvision.utils import make_grid

# read images
img1 = read_image('elephant.jpg')
img2 = read_image('cat.jpg')
img3 = read_image('dog.jpg')
print("size of img1:", img1.size())
print("size of img2:", img2.size())
print("size of img3:", img3.size())

# make grid
grid = make_grid([img1, img2, img3])
print("size of grid:", grid.size())

# print("grid:
", grid) img = torchvision.transforms.ToPILImage()(grid) img.show()

输出

size of img1: torch.Size([3, 466, 700])
size of img2: torch.Size([3, 466, 700])
size of img3: torch.Size([3, 466, 700])
size of grid: torch.Size([3, 470, 2108])

示例 2

在以下 Python 程序中,我们读取四个输入图像并创建这些图像的网格。我们将 nrow 设置为 2,以便网格每行有两个图像。

# Import the required library
import torch
import torchvision
from torchvision.io import read_image
from torchvision.utils import make_grid

# read images
img1 = read_image('elephant.jpg')

# img1 = read_image('car.jpg')
print("Size of image:",img1.size())
img2 = read_image('cat.jpg')
img3 = read_image('dog.jpg')
img4 = read_image('leopard.jpg')

# make grid
grid = make_grid([img1, img2, img3, img4], nrow = 2)
print("size of grid:", grid.size())

# print("grid:
", grid) img = torchvision.transforms.ToPILImage()(grid) img.show()

输出

Size of image: torch.Size([3, 466, 700])
size of grid: torch.Size([3, 938, 1406])

更新于: 2022年1月20日

4K+ 阅读量

启动你的 职业生涯

通过完成课程获得认证

开始学习
广告