PyTorch 中,加载数据是深度学习工作流程中的重要步骤。PyTorch 提供了多种方式来加载数据,尤其是通过 torch.utils.data 中的 DataLoader 类,配合 Dataset 类来高效地加载和处理数据集,并在训练过程中应用各种数据预处理和增强策略。

1、创建自定义 Dataset

torch.utils.data.Dataset 是 PyTorch 中表示数据集的类。我们可以继承该类并实现两个方法,__len__ 返回数据集的大小,__getitem__ 返回数据集中的一个样本。

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # 返回数据和标签
        sample = {'data': self.data[idx], 'label': self.labels[idx]}
        return sample

# 示例数据
data = torch.randn(100, 3)  # 100个样本,每个样本3个特征
labels = torch.randint(0, 2, (100,))  # 100个标签,0或1

# 创建数据集
dataset = MyDataset(data, labels)

2、使用 DataLoader 加载数据

DataLoader 是 PyTorch 中用于批量加载数据的类,它支持并行加载、数据打乱(shuffling)、以及批次化。

# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# 遍历 DataLoader
for batch in dataloader:
    data_batch = batch['data']
    label_batch = batch['label']
    print(data_batch, label_batch)

3、加载常见数据集(如 MNIST、CIFAR-10)

PyTorch 提供了多个常见数据集的封装,比如 torchvision.datasets 模块中包含了 MNIST、CIFAR-10 等数据集。这些数据集已经被封装为 Dataset 对象,方便直接使用。

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义数据预处理(转换为 Tensor 并标准化)
transform = transforms.Compose([transforms.ToTensor(), 
transforms.Normalize((0.5,), (0.5,))])

# 下载并加载训练集
train_dataset = datasets.MNIST(root='./data', train=True, 
download=True, transform=transform)

# 使用 DataLoader 加载数据
train_loader = DataLoader(train_dataset, batch_size=64, 
shuffle=True)

# 遍历 DataLoader
for images, labels in train_loader:
    print(images.shape, labels.shape)

4、使用其他数据集(如 CIFAR-10)

类似于 MNIST,也可以加载 CIFAR-10 数据集,处理方式完全相同。

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义数据预处理(转换为 Tensor 并标准化)
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 下载并加载训练集
train_dataset = datasets.CIFAR10(root='./data',
train=True, download=True, transform=transform)

# 使用 DataLoader 加载数据
train_loader = DataLoader(train_dataset, batch_size=64,
shuffle=True)

# 遍历 DataLoader
for images, labels in train_loader:
    print(images.shape, labels.shape)

5、自定义数据预处理

可以对数据进行自定义的预处理操作,如裁剪、旋转、颜色变换等,这可以通过 transforms 来实现。

from torchvision import transforms
from PIL import Image

# 定义数据预处理(将图片缩放至 224x224,转换为 Tensor)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])
])

# 加载一张图片并应用预处理
img = Image.open('path_to_image.jpg')
img_tensor = transform(img)

6、加速数据加载(多线程)

PyTorch 的 DataLoader 支持多线程加载数据,可以通过设置 num_workers 参数来加速数据加载过程,尤其是在读取和处理大量数据时。

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)

推荐文档