How to Load a Computer Vision Dataset in PyTorch?
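PyTorch provides many ready-made datasets for computer vision tasks. **torch.utils.data.Dataset** is the abstract base class that represents a dataset. The **torchvision.datasets** module contains many image and video datasets, each implemented as a subclass of **torch.utils.data.Dataset**. PyTorch also provides **torch.utils.data.DataLoader**, which wraps a dataset and loads its samples in batches.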
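To make the relationship between the two classes concrete, here is a minimal sketch of a custom map-style dataset wrapped in a data loader. The class name `RandomImageDataset` and its random contents are hypothetical and exist only for illustration; a real dataset would read images from disk.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class RandomImageDataset(Dataset):
    """Hypothetical dataset of random 3x32x32 'images' (illustration only)."""

    def __init__(self, num_samples=10, num_classes=10, seed=0):
        gen = torch.Generator().manual_seed(seed)
        self.images = torch.rand(num_samples, 3, 32, 32, generator=gen)
        self.labels = torch.randint(0, num_classes, (num_samples,), generator=gen)

    def __len__(self):
        # Number of samples in the dataset
        return len(self.labels)

    def __getitem__(self, idx):
        # Return one (image, label) pair
        return self.images[idx], self.labels[idx]


dataset = RandomImageDataset()
loader = DataLoader(dataset, batch_size=4)

# The loader stacks individual samples into batches
first_images, first_labels = next(iter(loader))
print(len(dataset), first_images.shape, first_labels.shape)
```

A subclass only needs `__len__` and `__getitem__`; the data loader takes care of batching.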
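Steps
We can load a computer vision dataset with the following steps:
Import the required libraries. All of the following examples require the Python libraries **torch**, **Matplotlib**, and **torchvision**. Make sure you have installed them.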
    import torch
    import torchvision
    from torchvision import datasets
    from torchvision.transforms import ToTensor
    import matplotlib.pyplot as plt
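Load the CIFAR10 training and test datasets with **datasets.CIFAR10()**. Pass **train=True** for the training dataset and **train=False** for the test dataset.

    training_data = datasets.CIFAR10(root="data", train=True, download=True, transform=ToTensor())
    test_data = datasets.CIFAR10(root="data", train=False, download=True, transform=ToTensor())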
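Define the training data loader (**trainloader**) and the test data loader (**testloader**). Specify **batch_size**, and set **shuffle=True** to get the images in random order. The class label names are available from the dataset's **classes** attribute.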
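The effect of **batch_size** and **shuffle** can be seen without downloading anything. The sketch below uses a tiny stand-in dataset built with `TensorDataset`; the names `toy_data`, `ordered_loader`, and `shuffled_loader` are made up for this illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten stand-in samples; each "image" is just its own index
toy_data = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))

# shuffle=False keeps the dataset order; shuffle=True draws a new random order each epoch
ordered_loader = DataLoader(toy_data, batch_size=4, shuffle=False)
shuffled_loader = DataLoader(
    toy_data,
    batch_size=4,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),  # fixed seed for repeatability
)

ordered_sizes = [len(labels) for _, labels in ordered_loader]
shuffled_labels = torch.cat([labels for _, labels in shuffled_loader])

print(ordered_sizes)  # [4, 4, 2]: the last batch holds the leftover samples
print(shuffled_labels.tolist())
```

Shuffling changes only the order: every sample still appears exactly once per pass over the loader.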
Get a batch of random images and labels from the training or test data loader.

    dataiter = iter(trainloader)
    images, labels = next(dataiter)
Visualize the obtained images along with their labels.
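One detail matters when plotting: image tensors produced by `ToTensor()` are laid out as (channels, height, width), while Matplotlib's `imshow` expects (height, width, channels) for RGB images. A minimal sketch of the conversion, using a random tensor as a stand-in image:

```python
import torch

# Stand-in for one image from a batch: channels first (C, H, W)
image_chw = torch.rand(3, 32, 32)

# Move the channel axis last so the array is (H, W, C), as imshow expects
image_hwc = image_chw.permute(1, 2, 0).numpy()

print(tuple(image_chw.shape), image_hwc.shape)
```

Calling `.numpy().transpose(1, 2, 0)` on the tensor yields the same layout.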
Example 1
In the following Python program, we load the CIFAR10 training and test datasets.
    # Import the required libraries
    import torch
    import torchvision
    from torchvision import datasets
    from torchvision.transforms import ToTensor

    # define batch size
    batch_size = 4

    # download CIFAR10 training and test datasets
    training_data = datasets.CIFAR10(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )
    test_data = datasets.CIFAR10(
        root="data",
        train=False,
        download=True,
        transform=ToTensor()
    )

    # define train and test dataloader
    trainloader = torch.utils.data.DataLoader(training_data,
        batch_size=batch_size, shuffle=True, num_workers=2)
    testloader = torch.utils.data.DataLoader(test_data,
        batch_size=batch_size, shuffle=False, num_workers=2)

    # access names of the labels
    label_names = training_data.classes

    # display details about the dataset
    print("label_names:\n", label_names)
    print("class label name to index:\n", training_data.class_to_idx)
    print("Shape of training data:\n", training_data.data.shape)
    print("Shape of test data:\n", test_data.data.shape)
Output
    Files already downloaded and verified
    Files already downloaded and verified
    label_names:
     ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    class label name to index:
     {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
    Shape of training data:
     (50000, 32, 32, 3)
    Shape of test data:
     (10000, 32, 32, 3)
Example 2
In this Python program, we load the CIFAR10 dataset and visualize some random images with their label names.
    import torch
    import torchvision
    from torchvision import datasets
    from torchvision.transforms import ToTensor
    import matplotlib.pyplot as plt

    batch_size = 4
    training_data = datasets.CIFAR10(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )
    test_data = datasets.CIFAR10(
        root="data",
        train=False,
        download=True,
        transform=ToTensor()
    )

    # shuffle=True so that the first batch contains random training images
    # (with num_workers > 0, Windows and macOS may require running this
    # script inside an `if __name__ == "__main__":` guard)
    trainloader = torch.utils.data.DataLoader(training_data,
        batch_size=batch_size, shuffle=True, num_workers=2)
    testloader = torch.utils.data.DataLoader(test_data,
        batch_size=batch_size, shuffle=False, num_workers=2)
    label_names = training_data.classes

    # get some random training images
    dataiter = iter(trainloader)
    images, labels = next(dataiter)

    # display random images
    # define figure
    fig = plt.figure(figsize=(8, 5))
    columns, rows = batch_size, 1

    # visualize these random images
    for i in range(1, columns * rows + 1):
        fig.add_subplot(rows, columns, i)
        # convert (C, H, W) to (H, W, C) for imshow
        plt.imshow(images[i-1].numpy().transpose(1, 2, 0))
        plt.xticks([])
        plt.yticks([])
        plt.title(label_names[labels[i-1]])
    plt.show()
Output

    Files already downloaded and verified
    Files already downloaded and verified