PyTorch - 数据集



在本章中,我们将重点讨论 torchvision.datasets 及其各种类型。PyTorch 包含以下数据集加载器 -

  • MNIST
  • COCO(字幕和检测)

数据集主要包含以下两种类型的函数 -

  • 变换 - 一个函数,接收一张图像并返回标准内容的修改版本。这些内容可以与变换结合在一起。

  • 目标变换 - 一个函数,接收目标并将其变换。例如,接收字幕字符串并返回文字索引的张量。

MNIST

以下是 MNIST 数据集的示例代码 -

dset.MNIST(root, train = TRUE, transform = NONE, 
target_transform = None, download = FALSE)

参数如下 -

  • root - 已处理数据所在的​​数据集的根目录。

  • train - True = 训练集,False = 测试集

  • download - True = 从互联网下载数据集并将其放入根目录。

COCO

这需要安装 COCO API。以下示例用于演示使用 PyTorch 实现 COCO 数据集 -

import torchvision.dataset as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = ‘ dir where images are’, 
annFile = ’json annotation file’,
transform = transforms.ToTensor())
print(‘Number of samples: ‘, len(cap))
print(target)

获得的输出如下 -

Number of samples: 82783
Image Size: (3L, 427L, 640L)
广告