在 Pytorch 中加载数据


每个机器学习项目都依赖于数据,由 Facebook 创建的著名开源机器学习工具包 PyTorch 也不例外。本手册旨在简化 PyTorch 中的数据加载过程,并帮助您尽快开始使用。

本文将重点介绍 PyTorch 的 DataLoader、Dataset 和 Transform 类。我们将通过一些实际示例来帮助您理解这些 PyTorch 核心概念,并简化您的机器学习应用程序。

PyTorch 数据加载:简要概述

为了导入和准备数据,PyTorch 提供了一个强大且灵活的工具箱。三个关键要素是:

  • Dataset  这个抽象类代表一个数据集,它允许以任何格式加载数据。只需要重写两个方法 __getitem__() 和 __len__()。

  • DataLoader  它封装了一个 Dataset,并提供对底层数据的快速访问。它会自动构建批次、随机打乱数据,并使用多线程并行加载数据。

  • Transforms  这些是常见的图像修改。可以通过 Compose 将转换链接在一起。这使您可以创建一个预处理操作管道,可以将其应用于加载的数据。

将数据加载到 PyTorch:示例

考虑一个图像集合,其中每个图像都表示为一个 3D NumPy 数组,并且标签与图像分开存储。以下是如何将此数据添加到 PyTorch 的快速方法。

from torch.utils.data import Dataset, DataLoader
import numpy as np

class ImageDataset(Dataset):
   def __init__(self, images, labels):
      self.images = images
      self.labels = labels

   def __getitem__(self, index):
      return self.images[index], self.labels[index]

   def __len__(self):
      return len(self.labels)

# Let's assume we have image data in NumPy arrays
images = np.random.rand(10000, 3, 32, 32)
labels = np.random.randint(0, 10, 10000)

dataset = ImageDataset(images, labels)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

我们在上述代码中创建了一个自定义的 Dataset 类。__len__ 函数返回图像的总数,而 __getitem__ 方法返回给定索引处的图像和标签。然后,DataLoader 将包装此 Dataset,它将处理批处理和数据随机打乱。

在 PyTorch 中使用 Transforms

您可以使用转换以灵活的方式预处理数据。例如,在基于图像的任务中,我们通常需要对数据进行归一化、将其转换为张量或使用数据增强技术。使用 PyTorch 的转换模块,这些任务变得非常简单。

from torchvision import transforms

# Define a transform to normalize the data
transform = transforms.Compose([
   transforms.ToTensor(),
   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Apply the transform to all images in the dataset
class ImageDataset(Dataset):
   def __init__(self, images, labels, transform=None):
      self.images = images
      self.labels = labels
      self.transform = transform

   def __getitem__(self, index):
      image = self.images[index]
      if self.transform:
         image = self.transform(image)
      return image, self.labels[index]

   def __len__(self):
      return len(self.labels)

dataset = ImageDataset(images, labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

在此示例中,转换在归一化后将图像数据转换为 PyTorch 张量。当我们实例化我们的 ImageDataset 时,我们将此转换传递给它,然后它将在 '__getitem__' 方法中应用于每个图像。

从 CSV 文件加载数据

对于诸如回归分析和分类之类的操作,通常需要加载来自 CSV 文件的数据。让我们使用 pandas 加载 CSV 文件、处理数据并构建 PyTorch DataLoader。

import pandas as pd
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import TensorDataset

# Load the data from a CSV file
df = pd.read_csv('data.csv')

# Convert categorical data to numerical data
le = LabelEncoder()
df['category'] = le.fit_transform(df['category'])

# Split the data into inputs and targets
inputs = df.drop('category', axis=1).values
targets = df['category'].values

# Convert to PyTorch Dataset
dataset = TensorDataset(torch.from_numpy(inputs), torch.from_numpy(targets))

# Wrap in a DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

在此示例中,pandas 用于从 CSV 文件加载数据。然后,Scikit-Learn 中的 LabelEncoder 函数用于将分类数据转换为数值数据。输入和目标被分割,它们被转换为 PyTorch 张量,并创建了一个 TensorDataset。最后,我们创建了一个 DataLoader 来处理批处理和随机打乱。

结论

在 PyTorch 中创建有效的机器学习模型,数据加载是一项基本技能。使用 PyTorch 的 DataLoader、Dataset 和 Transform 类,这项工作变得更简单、更高效。无论您是在处理表格数据还是图像数据,都可以修改这些类以满足您的需求。

更新于: 2023-07-18

271 次查看

开启您的 职业生涯

通过完成课程获得认证

立即开始
广告