如何在PyTorch中加载计算机视觉数据集?


PyTorch中有许多与计算机视觉任务相关的可用数据集。**torch.utils.data.Dataset** 提供不同类型的数据集。**torchvision.datasets** 是 **torch.utils.data.Dataset** 的子类,包含许多与图像和视频相关的数据集。PyTorch还提供了一个 **torch.utils.data.DataLoader**,用于从数据集中加载多个样本。

步骤

我们可以使用以下步骤加载计算机视觉数据集:

  • 导入所需的库。在以下所有示例中,所需的Python库为**torch**、**Matplotlib** 和 **torchvision**。确保您已安装它们。

import torch
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
  • 我们使用 **datasets.CIFAR10()** 加载 CIFAR10 训练和测试数据集,参数 **train=True** 用于训练数据集,**train=False** 用于测试数据集。

root="data",
train=True,
download=True,
transform=ToTensor()
  • 定义训练数据加载器 (**trainloader**) 和测试数据加载器 (**testloader**)。指定 **batch_size**。设置 **Shuffle=True** 以获得随机排列的图像。还可以访问类标签名称。

  • 从训练或测试数据集中获取一些随机图像和标签。

dataiter = iter(trainloader)
images, labels = dataiter.next()
  • 使用标签可视化获得的图像。

示例 1

在下面的 Python 程序中,我们加载 CIFAR10 训练和测试数据集。

# Import the required libraries
import torch
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor

# define batch size
batch_size = 4

# download CIFAR10 training and test datasets
training_data = datasets.CIFAR10(
   root="data",
   train=True,
   download=True,
   transform=ToTensor()
)

test_data = datasets.CIFAR10(
   root="data",
   train=False,
   download=True,
   transform=ToTensor()
)

# define train and test dataloader
trainloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=2)

testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)

# access names of the labels
label_names = training_data.classes

# display details about the dataset
print("label_names:
", label_names) print("class label name to index:
", training_data.class_to_idx) print("Shape of training data:
", training_data.data.shape ) print("Shape of test data:
", test_data.data.shape )

输出

Files already downloaded and verified
Files already downloaded and verified
label_names:
   ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog',
      'frog', 'horse', 'ship', 'truck']
class label name to index:
   {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer':
      4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
Shape of training data:
   (50000, 32, 32, 3)
Shape of test data:
   (10000, 32, 32, 3)

示例 2

在这个 Python 程序中,我们加载 CIFAR10 数据集。我们还可视化一些带有标签名称的随机图像。

import torch
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

batch_size = 4
training_data = datasets.CIFAR10(
   root="data",
   train=True,
   download=True,
   transform=ToTensor()
)

test_data = datasets.CIFAR10(
   root="data",
   train=False,
   download=True,
   transform=ToTensor()
)

trainloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=False, num_workers=2)

testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)

label_names = training_data.classes

# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()

# display random images
# define figure
fig=plt.figure(figsize=(8, 5))
columns, rows = batch_size, 1

# visualize these random images
for i in range(1, columns*rows +1):
   fig.add_subplot(rows, columns, i)
   plt.imshow(images[i-1].numpy().transpose(1,2,0))
   plt.xticks([])
   plt.yticks([])
   plt.title(label_names[labels[i-1]])
plt.show()

输出

Files already downloaded and verified
Files already downloaded and verified

更新于:2022年1月25日

725 次浏览

启动您的职业生涯

通过完成课程获得认证

开始学习
广告