购买
下载掌阅APP,畅读海量书库
立即打开
畅读海量书库
扫码下载掌阅APP

2.3 懒人的福音
——Keras模型库

TensorFlow官方使用Keras作为高级接口的额外一个好处就是可以使用大量已编写好的模型作为一个自定义层而直接使用,不需要使用者亲手对模型进行编写。

举例来说,一般常用的深度学习模型,例如VGG和ResNet(重点模型!后面章节会完整详细地介绍)等,可以直接从tf.keras.applications这个模型下直接导入。图2.15列出了Keras自带的模型数目。

图2.15 Keras自带的模型数目

可以看到,对于大多数的图像处理模型,applications模块都已经将其打包到内部可以直接调用。本章将以ResNet50为例,详细地介绍直接TensorFlow中预定义的ResNet模型的调用和参数的载入方式,但是具体使用将在第6章讲解。

2.3.1 ResNet50模型和参数的载入

首先是模型的载入,笔者选择ResNet50模型作为载入的目标,即将图2.15中倒数第4个模型作为载入,导入代码如下:

     resnet = tf.keras.applications.ResNet50()  #(载入可能卡住,下文有解决办法)

如果是第一次载入这个模型,就会在终端里显示如图2.16所示的信息。

图2.16 第一次载入

这是因为第一次载入时,Keras在载入模型的同时将模型默认参数下载并载入,可能会由于网络原因卡住,因此模型终端有可能在此停止运行。解决的办法非常简单,使用下载工具将蓝色部分下载下来,之后显式地告诉Keras参数的位置即可,代码如下:

这里weight函数显式地告诉模型所需要载入的参数位置。

注意

由于是显式地引入参数地址,因此需要写成绝对地址。

下面看一下ResNet50模型在Keras中的源码定义,代码如图2.17所示。

图2.17 ResNet50模型的源码定义

这里classes参数是ResNet基于imagenet数据集预训练的分类数,一般而言,使用预训练模型是用作特征提取而不是完整地使用模型作为同样的“分类器”,因此直接屏蔽掉最上面一层的分类层即可,代码可以改成如下:

使用summary函数可以将ResNet50模型的结构打印出来,如图2.18所示。

图2.18 ResNet50模型的结构

可以看到这里的模型最后几层的名称和参数多少,这是已经载入模型参数后的模型结构。

可能有读者对include_top=False这个参数设置有疑问,实际上笔者在这里做的是基于已训练模型为基础的“迁移学习”任务。迁移学习是将已训练模型去掉最高层的顶端输出层作为新任务的特征提取器,即这里利用“imagenet”预训练的特征提取方法迁移到目标数据集上,并根据目标任务追加新的层作为特定的“接口层”,从而在目标任务上快速、高效地学习新的任务。

【程序2-12】

2.3.2 使用ResNet50作为特征提取层建立模型

下面使用ResNet50作为特征提取层建立一个特定的目标分类器,这里简单地进行二分类的分类。代码如下(讲解在代码后部):

【程序2-13】

一般来说,预训练的特征提取器放在自定义的模型第一层,主要是用作对数据集的特征提取,之后的全局池化层是对数据维度进行压缩,将4维的数据特征重新定义成2维,从而将特征从[batch_size,7,7,2048]降维到[batch_size,2048],读者可以自行打印查看。

Drop_out_layer是屏蔽掉某些层用作防止过拟合的层,而fc_layer是用作对特定目标的分类层,这里通过设置unit参数为2定义分类成2个类。

最后一步是对定义的各个层进行组合:

     Binary_classes =
tf.keras.Sequential([resnet_layer,flatten_layer,drop_out_layr,fc_layer])

Sequential函数将各个层组合成一个完整的模型,打印的模型结构如图2.19所示。

图2.19 组合成一个完整的模型

不同于直接对ResNet50预训练模型的结构,这里仅仅将ResNet50当成了一个自定义的层来使用,因此可以看到在结构打印上这里依次显示了各个层的名称和参数,最下方是模型参数的总数。

下面还有一个问题是关于参数的,可以看到基本上所有的参数都是可训练的,也就是在模型的训练过程中所有的参数都参与了计算和更新。对于某些任务来说,预训练模型的参数是不需要更新的,因此可以对ResNet50模型进行设置,代码如下:

【程序2-14】

相对于上一个代码段,这里额外设置了 resnet_layer.trainable = False, 显式地标注了resnet为不可训练的层,因此resnet的参数在模型中不参与训练。

这里有一个小技巧:通过模型的大概描述比较参数的训练多少,显示结果如图2.20所示。

图2.20 模型展示

从图2.20可以看到,这里Non-trainable的参数占了大部分,也就是resnet模型参数不参与训练。读者可以自行比较。

注意

在使用ResNet模型做特征提取器的时候,由于Keras中的ResNet50模型是使用imagenet数据集做的预训练模型,输入的数据最低为[224,224,3],因此如果使用相同的方法进行预训练模型的自定义,那么输入的数据维度最小要为[224,224,3]。

其他模型的调用请有兴趣的读者自行完成。 x+MqvW8JIeJY2UKXisBJQEj9Km9emB975IMCmY50VBkWMTCGveW53j8R0qxj335n

点击中间区域
呼出菜单
上一章
目录
下一章
×