优化器是神经网络中反向传导的求解方法,着重在两方面:
(1)设定学习率的变化,加速求解的收敛速度。
(2)避开马鞍点(Saddle Point)等局部最小值,并且找到全局的最小值(Global Minimum)。
优化的过程如图4.13,随着训练的过程,沿着等高线逐步逼近圆心,权重不断更新,最终得到近似最佳解。
图4.13 随机梯度下降法(Stochastic Gradient Descent, SGD)求解图示
PyTorch支持很多种不同的优化器,可参阅官网中关于优化器介绍 [8] ,大部分都是动态调整的学习率,一开始离最佳解很远时,学习率可加大,越接近最佳解,学习率就越小,以免错过最佳解。常见的优化器如下:
· SGD;
· Adam;
· RMSprop;
· Adadelta;
· Adagrad;
· Adamax;
· Nadam;
· AMSGrad。
各种优化器的公式可参考 Gradient Descent Optimizers [15] 或 10 Stochastic Gradient Descent Optimisation Algorithms+Cheat Sheet [16] ,优缺点比较可参考 Various Optimization Algorithms For Training Neural Network [17] 。
范例.列举常用的优化器并进行测试。
下列程序代码请参考【04_14_优化器.ipynb】。
(1)随机梯度下降法(Stochastic Gradient Descent, SGD):是最常见、最单纯的优化器,语法为:
torch.optim.SGD(model.parameters(), lr,, momentum=0, dampening=0,
weight_decay=0, nesterov=False)
可以设定为:
model.parameters():模型的参数(权重)。
lr:学习率,为必填参数,在未设定其他参数时,学习率为固定值。
momentum:学习率变化速率的动能。
weight_decay:L2惩罚项的权重衰减率。
nesterov:是否使用Nesterov momentum,默认值是False。要了解技术细节可参阅“Understanding Nesterov Momentum (NAG)” [18] 。
第14行:建立随机梯度下降法(SGD)优化器。
第22行:优化器执行一个步骤,反向传导,更新权重。
(2)Adam(Adaptive Moment Estimation):是最常用的优化器,这里引用Kingma等学者于2014年发表的 Adam: A Method for Stochastic Optimization [19] 一文所作的评论“Adam计算效率高、内存耗费少,适合大数据集及参数个数很多的模型”。
Adam语法:torch.optim.Adam(model.parameters(), lr, betas, eps, weight_decay, amsgrad)
model.parameters():模型的参数(权重)。
lr:学习率,为必填参数,若未设定其他参数,学习率为固定值。
betas:计算平均梯度及其平方项的系数。
eps:公式分母的加项,以改善优化的稳定性。
weight_decay:L2惩罚项的权重衰减率。
amsgrad:是否使用AMSGrad,技术细节可参阅《一文告诉你Adam、AdamW、Amsgrad区别和联系》 [20] 。
(3)另外还有几种常用的优化器:
Adagrad(Adaptive Gradient-based optimization):设定每个参数的学习率更新频率不同,较常变动的特征使用较小的学习率,较少调整,反之,使用较大的学习率,比较频繁地调整,主要是针对稀疏的数据集。
RMSprop:每次学习率更新是除以均方梯度(average of squared gradients),以指数的速度衰减。
Adadelta:是Adagrad改良版,学习率更新会配合过去的平均梯度调整。
各种优化器会在一些比较特殊的状况下,突破马鞍点,顺利找到全局的最小值,一般情况下采用Adam及预设参数值即可,大致都可以达到梯度下降的效果。网络上也有许多优化器的比较和动画,有兴趣的读者可参阅 Alec Radford's animations for optimization algorithms [21] 。
不管是神经层、Activation Function、损失函数或优化器,Functional API都有对应的函数,都在torch.nn.functional命名空间内,与torch.nn无差别,根据我们要采取哪一类的模型而定。相关Functional API函数可参阅官网torch.nn.functional介绍 [10] 。