1、创建自定义 Dataset
torch.utils.data.Dataset
是 PyTorch 中表示数据集的类。我们可以继承该类并实现两个方法,__len__
返回数据集的大小,__getitem__
返回数据集中的一个样本。
import torch from torch.utils.data import Dataset, DataLoader class MyDataset(Dataset): def __init__(self, data, labels): self.data = data self.labels = labels def __len__(self): return len(self.data) def __getitem__(self, idx): # 返回数据和标签 sample = {'data': self.data[idx], 'label': self.labels[idx]} return sample # 示例数据 data = torch.randn(100, 3) # 100个样本,每个样本3个特征 labels = torch.randint(0, 2, (100,)) # 100个标签,0或1 # 创建数据集 dataset = MyDataset(data, labels)
2、使用 DataLoader 加载数据
DataLoader
是 PyTorch 中用于批量加载数据的类,它支持并行加载、数据打乱(shuffling)、以及批次化。
# 创建 DataLoader dataloader = DataLoader(dataset, batch_size=10, shuffle=True) # 遍历 DataLoader for batch in dataloader: data_batch = batch['data'] label_batch = batch['label'] print(data_batch, label_batch)
3、加载常见数据集(如 MNIST、CIFAR-10)
PyTorch 提供了多个常见数据集的封装,比如 torchvision.datasets
模块中包含了 MNIST、CIFAR-10 等数据集。这些数据集已经被封装为 Dataset
对象,方便直接使用。
import torch from torchvision import datasets, transforms from torch.utils.data import DataLoader # 定义数据预处理(转换为 Tensor 并标准化) transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) # 下载并加载训练集 train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) # 使用 DataLoader 加载数据 train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) # 遍历 DataLoader for images, labels in train_loader: print(images.shape, labels.shape)
4、使用其他数据集(如 CIFAR-10)
类似于 MNIST,也可以加载 CIFAR-10 数据集,处理方式完全相同。
from torchvision import datasets, transforms from torch.utils.data import DataLoader # 定义数据预处理(转换为 Tensor 并标准化) transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # 下载并加载训练集 train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) # 使用 DataLoader 加载数据 train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) # 遍历 DataLoader for images, labels in train_loader: print(images.shape, labels.shape)
5、自定义数据预处理
可以对数据进行自定义的预处理操作,如裁剪、旋转、颜色变换等,这可以通过 transforms
来实现。
from torchvision import transforms from PIL import Image # 定义数据预处理(将图片缩放至 224x224,转换为 Tensor) transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # 加载一张图片并应用预处理 img = Image.open('path_to_image.jpg') img_tensor = transform(img)
6、加速数据加载(多线程)
PyTorch 的 DataLoader
支持多线程加载数据,可以通过设置 num_workers
参数来加速数据加载过程,尤其是在读取和处理大量数据时。
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)