如何在PyTorch中加载计算机视觉数据集?
PyTorch中有许多与计算机视觉任务相关的可用数据集。**torch.utils.data.Dataset** 提供不同类型的数据集。**torchvision.datasets** 是 **torch.utils.data.Dataset** 的子类,包含许多与图像和视频相关的数据集。PyTorch还提供了一个 **torch.utils.data.DataLoader**,用于从数据集中加载多个样本。
步骤
我们可以使用以下步骤加载计算机视觉数据集:
导入所需的库。在以下所有示例中,所需的Python库为**torch**、**Matplotlib** 和 **torchvision**。确保您已安装它们。
import torch import torchvision from torchvision import datasets from torchvision.transforms import ToTensor import matplotlib.pyplot as plt
我们使用 **datasets.CIFAR10()** 加载 CIFAR10 训练和测试数据集,参数 **train=True** 用于训练数据集,**train=False** 用于测试数据集。
root="data", train=True, download=True, transform=ToTensor()
定义训练数据加载器 (**trainloader**) 和测试数据加载器 (**testloader**)。指定 **batch_size**。设置 **Shuffle=True** 以获得随机排列的图像。还可以访问类标签名称。
从训练或测试数据集中获取一些随机图像和标签。
dataiter = iter(trainloader) images, labels = dataiter.next()
使用标签可视化获得的图像。
示例 1
在下面的 Python 程序中,我们加载 CIFAR10 训练和测试数据集。
# Import the required libraries
import torch
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
# define batch size
batch_size = 4
# download CIFAR10 training and test datasets
training_data = datasets.CIFAR10(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.CIFAR10(
root="data",
train=False,
download=True,
transform=ToTensor()
)
# define train and test dataloader
trainloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)
# access names of the labels
label_names = training_data.classes
# display details about the dataset
print("label_names:
", label_names)
print("class label name to index:
", training_data.class_to_idx)
print("Shape of training data:
", training_data.data.shape )
print("Shape of test data:
", test_data.data.shape )输出
Files already downloaded and verified
Files already downloaded and verified
label_names:
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog',
'frog', 'horse', 'ship', 'truck']
class label name to index:
{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer':
4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
Shape of training data:
(50000, 32, 32, 3)
Shape of test data:
(10000, 32, 32, 3)示例 2
在这个 Python 程序中,我们加载 CIFAR10 数据集。我们还可视化一些带有标签名称的随机图像。
import torch import torchvision from torchvision import datasets from torchvision.transforms import ToTensor import matplotlib.pyplot as plt batch_size = 4 training_data = datasets.CIFAR10( root="data", train=True, download=True, transform=ToTensor() ) test_data = datasets.CIFAR10( root="data", train=False, download=True, transform=ToTensor() ) trainloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, shuffle=False, num_workers=2) testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2) label_names = training_data.classes # get some random training images dataiter = iter(trainloader) images, labels = dataiter.next() # display random images # define figure fig=plt.figure(figsize=(8, 5)) columns, rows = batch_size, 1 # visualize these random images for i in range(1, columns*rows +1): fig.add_subplot(rows, columns, i) plt.imshow(images[i-1].numpy().transpose(1,2,0)) plt.xticks([]) plt.yticks([]) plt.title(label_names[labels[i-1]]) plt.show()
输出
Files already downloaded and verified Files already downloaded and verified

广告
数据结构
网络
关系数据库管理系统 (RDBMS)
操作系统
Java
iOS
HTML
CSS
Android
Python
C语言编程
C++
C#
MongoDB
MySQL
Javascript
PHP