购买
下载掌阅APP,畅读海量书库
立即打开
畅读海量书库
扫码下载掌阅APP

5-1 数据集及数据加载器

torch.utils.data.Dataset是PyTorch内建数据结构,可同时存储特征( x )及目标( y ),包含一些内建的数据集:

(1)影像数据集:例如MNIST、FashionMNIST等,可参考Pytorch官网torchvision.datasets [1]

(2)语音数据集:可参考Pytorch官网torchaudio.datasets [2]

(3)文字数据集:可参考Pytorch官网torchtext.datasets [3]

(4)除此之外,还可以自定义数据集。

范例.加载数据集,并读取相关数据。

下列程序代码请参考【05_01_Datasets.ipynb】。

(1)载入套件。

(2)检查是否有GPU。

(3)加载MNIST手写阿拉伯数字数据。MNIST等数据集都有5个参数。

根路径(root):数据集下载后存储的目录,空字符串表示目前文件夹。

train:True表示下载训练数据集;False表示下载测试数据集。

download:True表示数据集不存在则自网络下载;False表示不会自动下载。

transform:数据集读入后特征( x )要做何种转换,至少要转成PyTorch Tensor。各种转换可参考Pytorch官网torchvision.transforms [4]

target_transform:数据集读入后,目标( y )要做何种转换。

(4)读取数据:直接指定索引值,例如train_ds.data[0],即可读取第一笔数据。 注意,使用train_ds.data[0]读取数据并不会应用到Transform函数,必须使用DataLoader读取数据,Transform函数才会发生效果。

执行结果。

(5)再看另一个数据集FashionMNIST,同时说明数据转换(Transform)及自定义数据集(Custom Dataset)的用法。

(6)任意抽样9笔数据显示:labels_map是目标值与名称的对照。

执行结果。

(7)数据转换(Transform):PyTorch提供非常多的转换函数,包括转换成PyTorch Tensor、放大/缩小、剪裁、彩色转灰度、各种数据增强(Data Augmentation)的效果等,可减少数据前置处理的负担,TensorFlow目前缺乏类似的功能。我们先来看单张图片的转换,程序代码修改自Pytorch官网Illustration of transforms [5]

(8)读取范例图片文件:使用skimage套件内建的女航天员图像。

执行结果。

(9)转换输入须为Pillow格式,再以Pillow函数读取图片文件。

(10)定义绘图函数。

(11)图片放大/缩小。

执行结果:第1张为原图,之后为缩小成30%、50%、100%、原比例的图,可以看到缩小后再经ax.imshow放大,显示就变模糊了。

(12)自中心裁剪。

执行结果:第1张为原图,之后为裁剪成30%、50%、100%、原比例的图,可以看到以中心点为参考点,向外裁剪。

(13)FiveCrop:以左上、右上、左下、右下及中心点为参考点,一次裁剪5张图。

执行结果。

(14)转灰度。

执行结果。

(15)旁边补零:指定补零宽度为3、10、30、50。

执行结果:观察边框的宽窄。

总共超过20种转换,中文说明可参考《PyTorch学习笔记(三):transforms的二十二个方法》 [6] ,这些效果都可以任意组合至transforms.Compose函数内。

另外,处理图像时常会做特征缩放,在TensorFlow范例中会采取正规化(Normalization),公式为( x -min)/(max-min),使 x 的范围介于[0,1]之间,而PyTorch并未提供此转换,通常采用标准化,但却命名为Normalize,与Normalization有点混淆,公式为( x - μ )/ δ ,请特别注意。一般而言,标准化是假设 x 是常态分配,但像素颜色0~255,应该属均匀分配,采用正规化似乎比较合理,但PyTorch官网采用常态分配,我们在此不计较这些。

程序代码如下,含两组参数,第一组为RGB三色的平均数( μ ),第二组为RGB三色的标准偏差( δ ),这是从ImageNet大量数据集统计的结果:

transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

若图像为单色,程序代码如下:

transforms.Normalize((0.1307,), (0.3081,))

可参考程序 【04_07_手写阿拉伯数字辨识_Normalize.ipynb】

若要采取正规化,完整范例可参考程序 【05_02_手写阿拉伯数字辨识_ MinMaxScaler.ipynb】 ,辨识率不佳,可见PyTorch与TensorFlow在图像的细部处理上是有所差异的。

(16)接着,我们实操一个范例,并同时示范自定义数据集的做法,将一目录下的所有文件制作成数据集,并转换为正确的输入格式。

(17)先制作一个目标名称与代码的对照表,之后将文件名转换为目标代码。

(18)自定义数据集:自定义数据集类别必须包含三个方法:__init__、__len__、__getitem__,作用分别为初始化、总笔数、取得下一笔数据,这种方式不必一次加载所有图像,可以节省内存的耗用。

(19)加载【04_06_FashionMNIST辨识_完整版.ipynb】存储的模型。

(20)建立转换:依次转灰度、缩放、居中、转PyTorch Tensor。

(21)建立DataLoader:加载自定义数据集,进行测试。

执行结果:与【04_06_FashionMNIST辨识_完整版.ipynb】测试结果相同。

Dataset一次获取一笔数据,使用DataLoader则可以一次获取一“批”数据,方便我们做批量测试,加快训练及测试速度,参数如下:

第一个参数:Dataset。

batch_size:批量。

shuffle:读取数据前是否先洗牌。

不通过循环,一次获取一“批”数据。

执行结果:torch.Size([7, 28, 28]) tensor([8, 5, 5, 6, 0, 1, 1]),取出7笔数据(不足10笔)。

有关语音及文字数据集在后续章节再作介绍。 ao0cI78peUDmvMk3RtZDUt98m1OEQDitAXFut4RJwDGQBFSh4QuL8y3Q5ufNUpn3

点击中间区域
呼出菜单
上一章
目录
下一章
×