之前的辨识手写阿拉伯数字程序,还有以下缺点:
(1)使用MNIST的测试数据,辨识率达98%,但如果以绘图软件里使用鼠标书写的文件测试,辨识率就差很多。这是因为MNIST的训练数据与鼠标书写的样式有所差异,MNIST数据是请受测者先写在纸上,再扫描存盘,所以图像会有深浅不一的灰度和锯齿状,所以,如果要实际应用,还是须自行收集训练数据,准确率才会提升。
(2)若要自行收集数据,找上万个测试者书写,可能不太容易,又加上有些人书写字体可能会有歪斜、偏一边或大小不同,都会影响预测准确度,这时可以借由数据增强(Data Augmentation)技术,自动产生各种变形的训练数据,让模型更强健(Robust)。
数据增强可将一张正常图像转换成各种的图像,例如旋转、偏移、拉近/拉远、亮度等效果,将这些数据当作训练数据训练出来的模型,就能较好地辨识有缺陷的图像。
PyTorch提供的数据增强函数很多元,可参阅 TRANSFORMING AND AUGMENTING IMAGES [3] ,我们已在5-1节测试过,详情请参阅程序 【05_01_Datasets.ipynb】 。
范例1.将数据增强函数整合至【06_03_MNIST_CNN_Normalize.ipynb】中。
完整程序请参阅【06_05_Data_Augmentation_MNIST.ipynb】。
(1)程序几乎不需改变,只要在数据转换方式加上随机转换(Random*),注意,数据增强函数很多,但阿拉伯数字有书写方向,有些随机转换不可采用,例如水平转换(RandomHorizontalFlip),若使用的话,3就变成ε了。相关的效果可参阅PyTorch官网 Illustration of transforms [4] 。
训练数据要数据增强,测试数据不需要。
(2)以测试数据评分,准确率为98.54%,并未显著提高。
(3)测试自行书写的数字,原来的模型无法正确辨识笔者写的9,经过数据增强后,已经可以正确辨识了。注意使用不同套件,读出的像素值区间会有所不同。
使用torchvision.io.read_image读取文件,像素介于[0, 1],回传的数据类型为torch.Tensor。
使用PIL读取文件,像素介于[0, 255],回传的数据类型为image,须以np.array()转换成NumPy ndarray,才能进行运算,之后,再利用Image.fromarray()转换回Image。
使用scikit-image读取文件,像素介于[0, 1],回传的数据类型为NumPy ndarray。
(4)先使用PIL读取文件,测试自行书写的数字。
执行结果:准确率为100%。
(5)使用scikit-image读取文件,测试自行书写的数字。
执行结果:预测结果相同。
(6)自定义数据集:可与训练数据采取一致的转换,不易出错。若是以次目录名称为标注,可直接使用torchvision.datasets.ImageFolder,不必自定义数据集。
(7)使用自定义数据集预测。
执行结果:准确率为100%。
(8)验证:修改test函数,显示每一笔数据的实际值与预测值。
执行结果:
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device=’cuda:0’)
范例2.单色图像结果非常完美,我们进一步编写书写接口,试试看模型是否可以派上用场。
程序: cnn_desktop\main.py 。
(1)先复制【06_05_Data_Augmentation_MNIST.ipynb】程序产生的cnn_augmentation_model.pt模型至本程序所在目录。
(2)载入套件。
(3)加载模型。
(4)预测。
(5)复制模型结构,会发现使用Functional API的Class也需放入程序中,才能顺利使用模型。
(6)窗口接口使用Tkinter,细节请参考程序文件。
(7)执行python main.py:以鼠标书写后,单击“辨识”按钮,辨识结果就会出现在右下文字框中,测试结果非常好。
范例3.接着再试试CIFAR彩色图像的数据增强。
下列程序代码请参考【06_06_Data_Augmentation_CIFAR.ipynb】。
(1)数据转换加上水平翻转(RandomHorizontalFlip)、随机裁切(RandomCrop),其他转换也可以添加,不过CIFAR图像分辨率过低,而且侦测的对象均占满整个图片,因此添加其他转换似乎并无太大帮助。
(2)改用CIFAR数据集。
(3)训练模型:发现添加转换后,一个训练周期的数据量仍然是50000笔,并未增加,也就是产生了更多样化的数据,但并未随同原来的训练数据一起被取出训练,只是以转换后的增强数据取代原数据,因此,笔者增加训练周期(10 → 20),以增加数据被广泛抽中的概率,同时也调小学习率,希望以更小的步幅寻求最佳解。
执行结果:观察下图,损失尚未收敛,准确率为49.34%,反而下降了,原因应该是背景过于复杂、训练数据不足,就算再多的转换也无济于事。
笔者在网络上搜寻其他先进的做法,发现两篇文章可供参考:
(1) PyTorch Implementation of CIFAR-10 Image Classification Pipeline Using VGG Like Network [5] ,较复杂的VGG模型(后续会介绍),使用数据增强,分别训练40/80/120/160/300周期,发现训练160周期后,准确率可达90%以上,如再训练更多周期,则会产生过度拟合的现象,即验证数据的准确率会背离训练的准确率,如下图。
(2) How Data Augmentation Improves your CNN performance? [6] 一文使用ResNet模型(后续会介绍),训练15周期,未使用数据增强,准确率可达75%;使用数据增强,准确率可达83%。
从以上的试验可知,数据增强并不重要,较复杂的模型及更多训练周期,才是提高CIFAR辨识准确率的关键因素。
TensorFlow提供的数据增强功能效果比较明显,读者可以比较看看,另外还有其他的函数库,提供更多的数据增强效果,比如Albumentations [7] ,包含的类型多达70种,很多都是TensorFlow/PyTorch所没有的效果,例如下图的颜色数据增强。