PyTorch 中,数据集(Dataset) 是深度学习训练流程中的基础组件,用于加载和处理数据。PyTorch 提供了灵活的方式来自定义或使用现成的数据集。本文主要介绍PyTorch中,torchvision.datasets及其各种类型。PyTorch包括以下数据集加载器:MNIST和COCO(字幕和检测)。

数据集包括下面给出的两种类型的大部分函数:

Transform:一个接收图像并返回标准内容的修改版本的函数。它们可以与转换组合在一起。

Target_transform:一个接受目标并对其进行转换的函数。例如,获取标题字符串并返回世界索引的张量。

常用模块:

模块/类作用
torch.utils.data.Dataset定义数据集的抽象基类,自定义数据集时需要继承它
torch.utils.data.DataLoaderDataset 包装成可迭代对象,支持批处理、打乱数据、多线程加载等功能

1、MNIST

以下是Mnist DataSet的示例代码:

dset.MNIST(root, train = TRUE, transform = NONE, 
target_transform = None, download = FALSE)

参数如下:

  • root:已处理数据所在的数据集的root目录。
  • train:True =训练集,False =测试集。
  • download:True =从互联网上下载数据集并将其放在root目录中。

2、COCO

COCO 是一个 大规模图像识别、分割和标注数据集,常用于目标检测、图像分割、图像字幕生成等任务。需要安装COCO API,示例使用PyTorch的数据集的COCO实现

import torchvision.dataset as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = ‘dir where images are’, 
annFile = ’json annotation file’,
transform = transforms.ToTensor())
print(‘Number of samples: ‘, len(cap))
print(target)

实现的输出如下:

Number of samples: 82783
Image Size: (3L, 427L, 640L)

3、常见内置数据集(torchvision.datasets

数据集说明
MNIST手写数字图像
CIFAR-10彩色图像分类(10类)
FashionMNIST衣服图像分类
ImageNet大型图像分类任务(需要单独下载)

推荐文档