PyTorch - 加载数据



PyTorch 包含一个名为 TorchVision 的包,用于加载和准备数据集。它包含两个基本函数,分别是 Dataset 和 DataLoader,可帮助转换和加载数据集。

数据集

数据集用于读取和转换给定数据集中的数据点。实现的基本语法如下所示 −

trainset = torchvision.datasets.CIFAR10(root = './data', train = True,
   download = True, transform = transform)

DataLoader 用于混洗和批处理数据。它可以用于并行加载具有多进程工作程序的数据。

trainloader = torch.utils.data.DataLoader(trainset, batch_size = 4,
   shuffle = True, num_workers = 2)

示例: 加载 CSV 文件

我们使用 Python 包 Panda 加载 csv 文件。原始文件具有以下格式:(图像名称、68 个地标 - 每个地标都有 x、y 坐标)。

landmarks_frame = pd.read_csv('faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)
广告