1、创建自定义 Dataset
torch.utils.data.Dataset
是 PyTorch 中表示数据集的类。我们可以继承该类并实现两个方法,__len__
返回数据集的大小,__getitem__
返回数据集中的一个样本。
import torch
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
# 返回数据和标签
sample = {'data': self.data[idx], 'label': self.labels[idx]}
return sample
# 示例数据
data = torch.randn(100, 3) # 100个样本,每个样本3个特征
labels = torch.randint(0, 2, (100,)) # 100个标签,0或1
# 创建数据集
dataset = MyDataset(data, labels)
2、使用 DataLoader 加载数据
DataLoader
是 PyTorch 中用于批量加载数据的类,它支持并行加载、数据打乱(shuffling)、以及批次化。
# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
# 遍历 DataLoader
for batch in dataloader:
data_batch = batch['data']
label_batch = batch['label']
print(data_batch, label_batch)
3、加载常见数据集(如 MNIST、CIFAR-10)
PyTorch 提供了多个常见数据集的封装,比如 torchvision.datasets
模块中包含了 MNIST、CIFAR-10 等数据集。这些数据集已经被封装为 Dataset
对象,方便直接使用。
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义数据预处理(转换为 Tensor 并标准化)
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
# 下载并加载训练集
train_dataset = datasets.MNIST(root='./data', train=True,
download=True, transform=transform)
# 使用 DataLoader 加载数据
train_loader = DataLoader(train_dataset, batch_size=64,
shuffle=True)
# 遍历 DataLoader
for images, labels in train_loader:
print(images.shape, labels.shape)
4、使用其他数据集(如 CIFAR-10)
类似于 MNIST,也可以加载 CIFAR-10 数据集,处理方式完全相同。
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义数据预处理(转换为 Tensor 并标准化)
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 下载并加载训练集
train_dataset = datasets.CIFAR10(root='./data',
train=True, download=True, transform=transform)
# 使用 DataLoader 加载数据
train_loader = DataLoader(train_dataset, batch_size=64,
shuffle=True)
# 遍历 DataLoader
for images, labels in train_loader:
print(images.shape, labels.shape)
5、自定义数据预处理
可以对数据进行自定义的预处理操作,如裁剪、旋转、颜色变换等,这可以通过 transforms
来实现。
from torchvision import transforms
from PIL import Image
# 定义数据预处理(将图片缩放至 224x224,转换为 Tensor)
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 加载一张图片并应用预处理
img = Image.open('path_to_image.jpg')
img_tensor = transform(img)
6、加速数据加载(多线程)
PyTorch 的 DataLoader
支持多线程加载数据,可以通过设置 num_workers
参数来加速数据加载过程,尤其是在读取和处理大量数据时。
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)