购买
下载掌阅APP,畅读海量书库
立即打开
畅读海量书库
扫码下载掌阅APP

3.1 大语言模型高效训练技术要点

本节只介绍大语言模型训练过程中的计算和内存效率的通用优化技术。这些技术是框架正交的,也就是不同框架中实现的独特的优化技术可以和这些技术直接结合使用,不互相干扰。这些技术在PyTorch以及Hugging Face的Transformers库中都已经实现。在下文中,提及内存默认是指GPU内存。如果是CPU内存,会显式指出。

1.混合精度训练 [1]

混合精度训练技术引入半精度表示fp16/bf16 ,在前向和反向传播过程中,模型参数以半精度表示,计算激活值和梯度时使用半精度,提升Tensor Core的吞吐量,从而提升训练速度。半精度表示在计算精度方面有所损失。对于需要保证高精度的操作,例如梯度截断、参数更新等,使用数值缩放技术可以减小精度损失的影响。

不过混合精度训练技术并没有带来内存效率的提升。在反向传播的最后时刻,为了保证参数更新的精度和效果,需要保存一份32位的优化器状态的副本。以Adam为例,其中包括32位的参数副本(有别于16位的参数副本)、32位的动量(momentum)副本和优化器状态(variance)副本。因此,使用Adam优化器训练一个参数量为| W |的模型,在训练过程中,参数、梯度和优化器状态的内存占用字节数为

2| W |+2| W |+3(4| W |)=16| W |

也就是说,整体内存占用相对于全精度训练保持不变。不过这至少也意味着,混合精度技术在提升计算效率的同时没有损失内存效率。因此,现在的大语言模型训练基本上都会开启混合精度训练。

2.小批次和梯度累积 [2]

前向和反向传播的单卡内存占用大小和批次大小成正比,这限制了单卡能处理的批次大小。同时,在分布式训练中,全局梯度计算涉及卡间/机间通信,这将降低计算效率。小批次技术将单卡上的一个批次划分为多个小批次,在一个个小批次上顺序计算梯度,在得到所有小批次的梯度后取平均值,再使用这个平均值参与全局梯度计算。通过这种方式,实现了在单卡上处理更大的批次,这样就可以在卡资源有限的情况下增大全局批次。同时,这样也减少了卡间/机间通信,提高了计算效率。

3.梯度/激活检查点 [3]

GPT系列模型在前向传播过程中产生的激活值(activation)数目的估算公式为

激活值= K ×批次大小×序列长度×隐含层大小×Transformer块数

其中, K =8+中间层大小/隐含层大小,中间层大小是FFN第一层全连接层的输出维度,隐含层大小等于嵌入大小。

这是一个巨大的数目,如果不做优化,将无法训练一个即使是中等参数规模的大语言模型。一个简单的想法是使用时间换空间的常规优化手段。激活值检查点(activation checkpointing)就是这样一种手段,它从Transformer块数目入手,不再缓存所有块的激活值,而是相隔几块缓存一次,这样将激活值的内存占用从 O N )降低为 O )。当然,这种优化将导致计算效率下降,计算时间将增加大约30%。鉴于内存效率的巨大优化,这个程度的计算效率下降是可以接受的。因此,现在的大语言模型训练基本上都开启激活值检查点优化。 xffoD2+fVLFMhkdCnNoQbDeLYNXCrVUmdK6NHedIttJX0IfAf7lCByksiwd7csg7

点击中间区域
呼出菜单
上一章
目录
下一章
×

打开