购买
下载掌阅APP,畅读海量书库
立即打开
畅读海量书库
扫码下载掌阅APP

2.2 利用TensorFlow训练CIFAR-10识别模型

在第2.1节中,读者已经对CIFAR-10数据集和TensorFlow中CIFAR-10数据集的读取有了基本的了解。本节将以TensorFlow为工具,训练CIFAR-10的图像识别模型。

2.2.1 数据增强

1.数据增强的原理

深度学习通常会要求拥有充足数量的训练样本。一般来说,数据的总量越多,训练得到的模型的效果就会越好。

在图像任务中,通常会观察到这样一种现象:对输入的图像进行一些简单的平移、缩放、颜色变换,并不会影响图像的类别。图 2-14 所示为翻转了位置的汽车图像,并适当降低了对比度和亮度,得到的图像当然还是汽车。 它们都可以被用作是汽车的训练样本。

对于图像类型的训练数据,所谓的数据增强(Data Augmentation)方法是指利用平移、缩放、颜色等变换,人工增大训练集样本的个数,从而获得更充足的训练数据,使模型训练的效果更好。

图2-14 图像数据增强的示例

常用的图像数据增强的方法如下。

●平移:将图像在一定尺度范围内平移。

●旋转:将图像在一定角度范围内旋转。

●翻转:水平翻转或上下翻转图像。

●裁剪:在原有图像上裁剪出一块。

●缩放:将图像在一定尺度内放大或缩小。

●颜色变换:对图像的RGB颜色空间进行一些变换。

●噪声扰动:给图像加入一些人工生成的噪声。

使用数据增强方法的前提是, 这些数据增强方法不会改变图像的原有标签。 例如在MNIST数据集中,如果使用数据增强,就不能使用旋转180°的方法,因为标签为“6”的数字在旋转180°后会变成“9”。

2.TensorFlow中数据增强的实现

训练CIFAR-10识别模型用到了数据增强来提高模型的性能。实验证明,使用数据增强可以大大提高模型的泛化能力,并且能够预防过拟合。

实现数据增强的代码在cifar10_input.py的distorted_inputs()函数中,几行代码如下:

原始的训练图片是reshaped_image。最后会得到一个数据增强后的训练样本distorted_image。从reshaped_image到distorted_image的处理步骤如下:

●第一步是对reshaped_image进行随机裁剪。原始的CIFAR-10图像的尺寸是 32×32。随机裁剪出 24×24 的小块进行训练。因为小块可以取在图像的任何位置,所以仅此一步就可以大大增加训练集的样本数目。

●第二步是对裁剪后的小块进行水平翻转。每张图片有50%的概率被水平翻转,还有50%的概率保持不变。

●最后对得到的图片进行亮度和对比度的随机改变。

训练时,直接使用distorted_image进行训练即可。

2.2.2 CIFAR-10识别模型

与MNIST识别模型一样,得到数据增强后的图像distorted_image后,需要建立一个模型将图像识别出来。建立模型的代码在 cifar10.py 文件的inference()函数中。这个函数的代码如下:

模型的代码虽然比较复杂,但本质是不变的,与第 1.2.2 节中的手写体识别模型类似,都是输入图像,输入图像对应到各个类别的 Logit。这里使用了两层卷积层,还在卷积层后面额外加了三层全连接层。

2.2.3 训练模型

用下列命令就可以训练模型:

--data_dir cifar10_data/的含义是指定 CIFAR-10 数据的保存位置。--train_dir cifar10_train/的作用是另外指定一个训练文件夹。训练文件夹的作用是保存模型的参数和训练时的日志信息。

训练模型时,屏幕上会显示日志信息,如:

日志信息告诉我们当前的时间和已经训练的步数,还会显示当前的损失是多少(如 loss=3.98)。理想的损失应该是一直下降的。日志里最后括号里的信息表示训练的速度。这里的日志信息是在 GPU 下训练时输出的,如果读者使用CPU进行训练,那么训练速度会比这里慢。

2.2.4 在TensorFlow中查看训练进度

在训练的时候,常常想知道损失的变化,以及各层的训练状况。TensorFlow提供了一个可视化工具TensorBoard。使用TensorBoard可以非常方便地观察损失的变化曲线,还可以观察训练速度等其他日志信息,达到实时监控训练过程的目的。

要使用 TensorBoard,请打开另一个命令行窗口,切换到当前目录,并输入以下命令:

TensorBoard 默认在 6006 端口运行。打开浏览器,输入地址http://127.0.0.1:6006(或 http://localhost:6006),就可以看到 TensorBoard 的主页面,如图2-15所示。

图2-15 TensorBoard的主页面

单击total_loss_1,就可以看到loss的变化曲线,变化曲线会根据时间实时变动,非常便于实时监测。还可以滑动左侧工具栏中的“Smoothing”滑条,它的功能是平滑损失曲线,方便更好地观察损失曲线的整体变化情况。

单击 learning_rate,可以监控学习率的变化。观察学习率时,应当把“Smoothing”滑条拖曳至 0,因为学习率的值是确定的,并不存在噪声,因此也不需要进行平滑处理。

图2-16展示了训练到约60万步时(此时运行训练程序终端的代码应该打出类似step 600000的日志),损失和学习率的变化情况。

从图中可以看出,在深度模型的训练中,通常先使用比较大的学习率(如0.1),这样可以帮助模型在初期以比较快的速度收敛。之后再逐步降低学习率(如降低到 0.01 或 0.001)。在 CIFAR-10 识别模型的训练中,学习率从0.1开始递减,依次是0.01,0.001,0.0001。每一次递减都可以让损失更进一步地下降。

图2-16 训练到约60万步时损失和学习率的变化情况

除了上述功能外,在TensorBoard中还可以监控模型的训练速度。展开global_step 选项卡,对应的图形为每秒训练步数的情况。如图 2-17 所示,每秒大概训练 8~11 步,变化不是特别大。在实际训练过程中,如果训练速度发生较大的变化,或者出现训练速度随程序运行而越来越慢的情形,就可能是程序中出现了错误,需要进行检查。

图2-17 训练速度的变化情况

最后,简要介绍TensorBoard显示训练信息的原理。在指定的训练文件夹 cifar10_train 下,可以找到一个以 events.out 开头的文件。实际上,在训练模型时,程序会源源不断地将日志信息写入这个文件中。运行TensorBoard时只要指定训练文件夹,TensorBoard 会自动搜索到这个文件,并在网页中显示相应的信息。

2.2.5 测试模型效果

在训练文件夹cifar10_train/下,还会发现一个checkpoint文件和一些以model.ckpt 开头的文件。TensorFlow 会将训练得到的模型参数保存到“checkpoint”里。在训练程序中,已经设定好每隔10min保存一次checkpoint,并且只保留最新的5个checkpoint,保存时如果已经有了5个checkpoint就会删除最旧的那个。

用记事本打开checkpoint文件,会发现类似如下的内容:

其中,model_checkpoint_path表示最新的模型是model.ckpt-601652(由于训练步数的不同,读者看到的数字可能和本书的有所不同)。601652表示这是第601652步的模型。后面的5个all_model_checkpoint_paths分别表示所有存储下来的5个模型和它们的步数。

使用cifar10_eval.py可以检测模型在CIFAR-10测试数据集上的准确性。

在命令行中运行代码:

--data_dir cifar10_data/表示 CIFAR-10 数据集的存储位置。--checkpoint_dir cifar10_train/则表示程序模型保存在cifar10_train/文件夹下。这里还用--eval_dir cifar10_eval/指定了一个保存测试信息的文件夹,测试时获得的结果(如准确率)会保存在cifar10_eval/中。

测试时要注意使用的是CPU还是GPU,总的来说有以下三种情况。

第一种情况是训练和测试都使用GPU。此时要注意不能在同一个GPU上运行命令,最好用另一个GPU进行测试,否则可能会由于显存不足,导致程序运行失败。使用另一张显卡的方法是设置不同的CUDA_VISIBLE_DEVICES环境变量。比如在训练时,先运行export CUDA_VISIBLE_DEVICES=0,再执行训练代码,这样训练程序只会使用 0号GPU。测试时,先运行export CUDA_VISIBLE_DEVICES=1,这样测试程序就会使用1号GPU。

第二种情况是使用GPU训练,用CPU测试。这种情况在测试时可以在命令行运行:export CUDA_VISIBLE_DEVICES=“”。这样测试程序将只会使用CPU进行测试,不会影响训练的GPU。

第三种情况是使用CPU进行训练和测试。此时如果系统没有设置GPU,那么直接运行相应的代码即可。

运行测试代码后,程序会立刻检测在最新checkpoint上的准确率。此外,它还会每隔一段时间自动执行一次,获取新保存的模型的准确率,并把所有信息写入文件夹cifar10_eval/中。

使用TensorBoard可以观察准确率随训练步数的变化情况。运行:

TensorBoard默认在6006端口运行,但这里使用--port 6007可以使它在6007端口运行。这是为了防止和之前运行的监控训练状况的TensorBoard发生端口冲突。打开http://127.0.0.1:6007,展开“Precision@1”选项卡,就可以看到准确率随训练步数变化的情况,如图2-18所示。

图2-18 模型的准确率随训练步数的变化情况

实际上到6万步左右时,模型就有了86%的准确率,到10万步时的准确率为86.3%,到15万步后的准确率基本稳定在86.6%左右。 bqhPaIsVzYxHOBeVMiEOTzio7isGyZ58v8eKvSDzaI+UWXzySa0Bhv22Cs2/7c3u

点击中间区域
呼出菜单
上一章
目录
下一章
×