如何在 PyTorch 中创建图像网格?
torchvision.utils 包提供了 make_grid() 函数来创建图像网格。图像应为 torch 张量。它接受形状为 (B ☓ C ☓ H ☓ W) 的 4D 小批量张量或大小相同的张量图像列表。
这里,B 是批大小,C 是图像中通道的数量,H 和 W 分别是高度和宽度。
所有图像的 H ☓ W 应该相同。
此函数的输出是一个包含图像网格的 torch 张量。我们可以使用 nrow 参数指定一行中的图像数量。我们还有许多其他参数来控制网格输出。要可视化图像网格,我们首先将整个网格转换为 PIL 图像。
语法
torchvision.utils.make_grid(tensor)
参数
tensor - 张量或张量列表。形状为 (B x C x H x W) 的 4D 小批量张量或大小相同的图像列表。
输出
它返回一个包含图像网格的 torch 张量。
步骤
导入所需的库。在以下所有示例中,所需的 Python 库为 torch 和 torchvision。确保您已安装它们。
import torch import torchvision from torchvision.io import read_image from torchvision.utils import make_grid
使用 image_read() 函数读取多个 JPEG 或 PNG 图像。使用图像类型(.jpg 或 .png)指定完整图像路径。此函数的输出是一个大小为 [image_channels, image_height, image_width] 的 torch 张量。
img1 = read_image('elephant.jpg') img2 = read_image('cat.jpg') img3 = read_image('dog.jpg')
使用 make_grid() 函数创建读取为 torch 张量的输入图像网格。指定 nrow 以在网格中每行显示的图像数量。
grid = make_grid([img1, img2, img3], nrow=3)
将网格张量转换为 PIL 图像并显示它。
img = torchvision.transforms.ToPILImage()(grid) img.show()
示例 1
在此 Python 程序中,我们读取三个输入图像并创建这些图像的网格。
import torch import torchvision from torchvision.io import read_image from torchvision.utils import make_grid # read images img1 = read_image('elephant.jpg') img2 = read_image('cat.jpg') img3 = read_image('dog.jpg') print("size of img1:", img1.size()) print("size of img2:", img2.size()) print("size of img3:", img3.size()) # make grid grid = make_grid([img1, img2, img3]) print("size of grid:", grid.size()) # print("grid:
", grid) img = torchvision.transforms.ToPILImage()(grid) img.show()
输出
size of img1: torch.Size([3, 466, 700]) size of img2: torch.Size([3, 466, 700]) size of img3: torch.Size([3, 466, 700]) size of grid: torch.Size([3, 470, 2108])
示例 2
在以下 Python 程序中,我们读取四个输入图像并创建这些图像的网格。我们将 nrow 设置为 2,以便网格每行有两个图像。
# Import the required library import torch import torchvision from torchvision.io import read_image from torchvision.utils import make_grid # read images img1 = read_image('elephant.jpg') # img1 = read_image('car.jpg') print("Size of image:",img1.size()) img2 = read_image('cat.jpg') img3 = read_image('dog.jpg') img4 = read_image('leopard.jpg') # make grid grid = make_grid([img1, img2, img3, img4], nrow = 2) print("size of grid:", grid.size()) # print("grid:
", grid) img = torchvision.transforms.ToPILImage()(grid) img.show()
输出
Size of image: torch.Size([3, 466, 700]) size of grid: torch.Size([3, 938, 1406])
广告