在 Pytorch 中加载数据
每个机器学习项目都依赖于数据,由 Facebook 创建的著名开源机器学习工具包 PyTorch 也不例外。本手册旨在简化 PyTorch 中的数据加载过程,并帮助您尽快开始使用。
本文将重点介绍 PyTorch 的 DataLoader、Dataset 和 Transform 类。我们将通过一些实际示例来帮助您理解这些 PyTorch 核心概念,并简化您的机器学习应用程序。
PyTorch 数据加载:简要概述
为了导入和准备数据,PyTorch 提供了一个强大且灵活的工具箱。三个关键要素是:
Dataset − 这个抽象类代表一个数据集,它允许以任何格式加载数据。只需要重写两个方法 __getitem__() 和 __len__()。
DataLoader − 它封装了一个 Dataset,并提供对底层数据的快速访问。它会自动构建批次、随机打乱数据,并使用多线程并行加载数据。
Transforms − 这些是常见的图像修改。可以通过 Compose 将转换链接在一起。这使您可以创建一个预处理操作管道,可以将其应用于加载的数据。
将数据加载到 PyTorch:示例
考虑一个图像集合,其中每个图像都表示为一个 3D NumPy 数组,并且标签与图像分开存储。以下是如何将此数据添加到 PyTorch 的快速方法。
from torch.utils.data import Dataset, DataLoader import numpy as np class ImageDataset(Dataset): def __init__(self, images, labels): self.images = images self.labels = labels def __getitem__(self, index): return self.images[index], self.labels[index] def __len__(self): return len(self.labels) # Let's assume we have image data in NumPy arrays images = np.random.rand(10000, 3, 32, 32) labels = np.random.randint(0, 10, 10000) dataset = ImageDataset(images, labels) dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
我们在上述代码中创建了一个自定义的 Dataset 类。__len__ 函数返回图像的总数,而 __getitem__ 方法返回给定索引处的图像和标签。然后,DataLoader 将包装此 Dataset,它将处理批处理和数据随机打乱。
在 PyTorch 中使用 Transforms
您可以使用转换以灵活的方式预处理数据。例如,在基于图像的任务中,我们通常需要对数据进行归一化、将其转换为张量或使用数据增强技术。使用 PyTorch 的转换模块,这些任务变得非常简单。
from torchvision import transforms # Define a transform to normalize the data transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # Apply the transform to all images in the dataset class ImageDataset(Dataset): def __init__(self, images, labels, transform=None): self.images = images self.labels = labels self.transform = transform def __getitem__(self, index): image = self.images[index] if self.transform: image = self.transform(image) return image, self.labels[index] def __len__(self): return len(self.labels) dataset = ImageDataset(images, labels, transform=transform) dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
在此示例中,转换在归一化后将图像数据转换为 PyTorch 张量。当我们实例化我们的 ImageDataset 时,我们将此转换传递给它,然后它将在 '__getitem__' 方法中应用于每个图像。
从 CSV 文件加载数据
对于诸如回归分析和分类之类的操作,通常需要加载来自 CSV 文件的数据。让我们使用 pandas 加载 CSV 文件、处理数据并构建 PyTorch DataLoader。
import pandas as pd from sklearn.preprocessing import LabelEncoder from torch.utils.data import TensorDataset # Load the data from a CSV file df = pd.read_csv('data.csv') # Convert categorical data to numerical data le = LabelEncoder() df['category'] = le.fit_transform(df['category']) # Split the data into inputs and targets inputs = df.drop('category', axis=1).values targets = df['category'].values # Convert to PyTorch Dataset dataset = TensorDataset(torch.from_numpy(inputs), torch.from_numpy(targets)) # Wrap in a DataLoader dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
在此示例中,pandas 用于从 CSV 文件加载数据。然后,Scikit-Learn 中的 LabelEncoder 函数用于将分类数据转换为数值数据。输入和目标被分割,它们被转换为 PyTorch 张量,并创建了一个 TensorDataset。最后,我们创建了一个 DataLoader 来处理批处理和随机打乱。
结论
在 PyTorch 中创建有效的机器学习模型,数据加载是一项基本技能。使用 PyTorch 的 DataLoader、Dataset 和 Transform 类,这项工作变得更简单、更高效。无论您是在处理表格数据还是图像数据,都可以修改这些类以满足您的需求。