购买
下载掌阅APP,畅读海量书库
立即打开
畅读海量书库
扫码下载掌阅APP

2.4 长上下文

支持长上下文意味着大模型可以更好地理解文本中的长程依赖(long-range dependency),可以在文本中距离很远的观点之间建立联系,从而生成全局更一致的文本输出。

2.4.1 采用RoPE位置编码的长上下文扩展

无论是Llama还是Llama 2都使用RoPE位置编码。本节讨论两种在RoPE位置编码方式下实现支持长上下文的两种具有代表性的方法:第一种方法实现简单,计算代价小,效果不错;第二种方法更为彻底,计算代价大,效果更好。

1.RoPE位置插值

在Hugging Face的Transformer库中,RoPE位置编码的实现方式如下:

cos,sin=self.rotary_emb(value_states,seq_len=kv_seq_len)
query_states,key_states=apply_rotary_pos_emb(
    query_states,key_states,cos,sin,position_ids
)

RoPE位置编码从values得到cos和sin波函数,将位置id转换为位置编码。

Kaiokendev团队 [14] 和Meta团队 [14] 各自独立发现的外推方法可以在不对模型进行继续预训练的情况下实现推理时支持数倍于预训练上下文长度的效果。具体实现可以参考LongChat [16] 的做法,主要步骤如下:

(1)压缩RoPE。

Llama模型是使用RoPE在序列长度2048上进行预训练的,这意味着在预训练阶段观察不到position_ids>2048的情况。研究团队没有强制Llama模型适应position_ids>2048,而是将position_ids>2048的部分压缩为0~2048。直观地说,研究团队假设这种压缩可以最大程度地重用在预训练阶段学到的模型权重。通过将目标新上下文长度 y 除以2048作为压缩率,然后将每个position_ids除以压缩率,并将其输入apply_rotary_pos_emb函数,代码如下:

query_states,key_states=apply_rotary_pos_emb(
    query_states,key_states,cos,sin,position_ids/ratio
)

在LongChat-16K中,研究团队将模型微调到上下文长度为16 384,将压缩率设为8。例如,把position_ids=10 000的词元变为position_ids=10000/8=1250,而相邻的下一个词元变为position_ids=10001/8=1250.125。

(2)微调精选的对话数据库。

在压缩RoPE之后,研究团队使用他们精心挑选的对话数据集执行微调过程。研究团队重新使用了先前用来训练Vicuna的用户分享对话数据,使用FastChat数据处理流程清理数据,截断这些对话,使其长度不超过16 000。然后再使用标准情况下一个词元的预测损失对模型进行微调。最后分别使用80 000个和18 000个对话对7B和13B模型进行微调。

2.RoPE位置编码修改+更长上下文继续预训练

这里介绍Meta研究团队的另一项上下文扩展的工作 [17] 。在这项工作中,研究团队对RoPE位置编码做了简单的修改,然后结合继续预训练和微调的方法达到对更长上下文的支持和理解能力。

通过7B规模的早期实验,Meta研究团队发现了Llama 2位置编码的一个关键限制,它阻止了注意力模块聚合远处词元的信息。研究团队对RoPE位置编码采用了最小但必要的修改以进行长上下文建模,即,减小旋转角度(由超参数基频b控制),从而减弱RoPE对长距离词元的衰减效应。

注意力操作的计算代价和输入序列长度的平方成正比,直接在预训练时使用长序列语料会带来显著的额外计算开销。所以,研究团队采用继续预训练的方法,在短序列语料上训练出来的预训练模型上使用长序列语料继续进行预训练,使用的语料规模是预训练的10%左右。要训练更长的上下文,一个重要的优化方向是提升注意力操作的计算和内存效率。为了进一步减少训练的计算代价,Meta研究团队采取了一种逐步加大输入序列长度的增量式训练方法。

此外,在微调阶段,需要解决长指令数据的构造问题。在这项工作中,Meta研究团队发现了一种简单而成本较低的方法:利用预先构建的大型且多样化的短提示数据集构造长指令数据集。实验结果显示,使用这种方法构造的长指令数据集微调过的模型在长上下文基准测试中效果非常好。具体来说,他们采用了Llama 2 Chat中使用的RLHF数据集,并使用了Llama 2 Chat本身生成的合成的self-instruct长指令对其进行增强。研究团队希望该模型能够通过大量RLHF数据学习到多种多样的技能,并通过self-instruct数据将这些知识转移到长上下文场景中。数据生成过程侧重于问答格式的任务:从预训练语料库中的长文档开始,随机选择一个文本块并提示Llama 2 Chat根据文本块中的信息编写问答对,由此收集到带有不同提示的长答案和短答案。然后,采取自我批评步骤,提示Llama 2 Chat验证模型生成的答案。给定生成的问答对,使用原始长文档(截断为模型的最大上下文长度)作为上下文构建训练实例。此外,为了提升训练效率,不浪费计算资源,在长短指令混合微调过程中,对短指令进行拼接处理。

2.4.2 注意力操作优化

长上下文支持的一个关键技术障碍是注意力操作的计算和内存代价与输入序列长度的平方成正比。因此,一个注意力操作的优化实现对高效训练至关重要。下面着重介绍斯坦福大学团队在这一方面的出色工作——FlashAttention [18]

无论是Llama还是Llama 2,都使用标准注意力计算的优化实现xformers [19] 。xformers是一个类似于FlashAttention的高效注意力计算的实现。不过,考虑到FlashAttention是一项更独立的第三方技术,更适合和各种网络架构集成,内存优化效率更高,而xformers更多地使用在Llama架构的模型中,这里以FlashAttention而非xformers作为示例进行讲解。

1.FlashAttention原理

只要学过计算机体系结构,就可以很轻松地理解FlashAttention的优化逻辑。GPU和CPU一样,存储也分为几个层级。GPU的HBM(High Bandwidth Memory,高带宽内存)和CPU的内存一样,存储容量比较大,但读写速度慢;GPU的SRAM和CPU的SRAM一样,存储容量小,但读写速度快。

注意力的计算主要涉及 Q K V 的矩阵乘法操作和 Q K 的softmax操作。在以前的注意力计算的实现中,计算的输入输出都是和HBM打交道,所以计算性能低,而且计算操作单元就是整个输入,占用SRAM空间大。那么,在GPU的这种层级存储结构体系下,如何利用优化技术提升上述两个操作的计算和内存效率?首先考虑计算效率优化。读写HBM速度慢,那么就尽量减少对它的读写。先来看一下优化之前标准注意力计算对HBM的读写情况。

这个算法要求输入的矩阵 Q K V 都是 R N × d 维的,并且存储在HBM中。这里的 R 表示实数集, N 是序列的长度, d 是特征的维度。 QK T 表示 Q K 的转置的矩阵乘法,softmax函数用于将 S 转换为概率分布,最后 PV 计算注意力加权的值。

可以看出,要计算出最终的 O ,需要进行3次计算: S = QK T P =softmax( QK T ), O = PV 。每次计算读写一次HBM,总共读写HBM 3次。3次计算都涉及整个矩阵的操作,中间结果太大,无法存储在SRAM中。因此,一个自然的想法是将矩阵 K V 分块,逐块操作,这样就可以将中间结果暂时存储在SRAM中,再通过一些运算变换,就可以在只读写一次HBM的情况下一次性将 O 计算出来。当然,虽然整个矩阵 K V 只读取一次,但是每读取一块都需要一次操作。图2 . 9展示了FlashAttention如何执行注意力计算。

图2.9 FlashAttention注意力计算

具体来说,这里涉及两个优化:一是平铺(tiling)和softmax重新缩放(rescaling),softmax的计算是逐块进行的而不是对整个矩阵进行的;二是重新计算(recomputing),在前向传播时,将softmax归一化因子存放在SRAM中,这样在反向传播时可以快速计算注意力。

FlashAttention实现的伪代码如下:

FlashAttention采用的逐块计算方式大大降低了SRAM的空间占用,使得预训练可以使用更长的上下文。

FlashAttention-2 [20] 在3方面做了改进,包括算法、并行性、工作切分方式等,比FlashAttention的性能提升了大约1倍。

(1)更好的算法。

FlashAttention-2对FlashAttention的算法进行了调整,减少了非矩阵乘法的数量。这对计算的加速非常关键。因为现代GPU有特殊的计算单元,例如Nvidia的GPU上的Tensor Cores,可以让矩阵乘法操作比其他操作快得多。例如,A100 GPU FP16/BF16矩阵乘法的最大理论吞吐量为312TFLOPS,而非矩阵乘法只有19.5TFLOPS。为了保持高吞吐量,需要将尽可能多的时间花在矩阵乘法上。FlashAttention-2的开发者重写了FlashAttention中使用的online softmax技巧以减少重新缩放操作以及边界检查和因果掩码(casual masking)操作的数量。

(2)更好的并行性。

FlashAttention是基于批次大小和注意力头数量进行并行操作的,使用一个线程块(thread block)处理一个注意力头,因此总共需要的线程块数是以上两个数的乘积。每个线程块调度运行在一个流式多处理器(Stream Multiprocessor,SM)上。当需要的线程数比较多,例如批次大小和注意力头数量比较大时,这种调度方式很高效,因此可以用到GPU上几乎所有的计算资源。不过,当预训练使用的上下文比较长时,通常只能处理比较小的批次,在这种情况下,在上下文长度上进行并行化操作是更高效的方式。

(3)更好的工作切分方式。

在每个线程块内如何切分工作,即,将每个 Q K V 块切分,在不同的Warp(一个Warp由一组线程组成)处理,减少Warp间的通信和同步次数,可以提升计算效率。图2.10展示了FlashAttention和FlashAttention-2前向传播时在不同Warp间切分工作的方式。

对于每个 Q K V 块,FlashAttention将 K V 切分成4份,在4个Warp中处理; Q 没有切分,所有Warp都可以存取。这种方式效率不高,因为所有的Warp都需要将中间结果写入共享内存,进行同步,然后将所有的中间结果加起来。这些对共享内存的读写操作减慢了前向传播中的注意力计算。在FlashAttention-2中,将 Q 切分成4份,在4个Warp中处理; K V 没有切分,所有Warp都可以存取。每个Warp执行完矩阵乘法后得到 QK T 的一个切片,只需要将它与共享的 V 的切片相乘,就可以得到相应的输出切片,不需要在Warp间进行通信。减少对共享内存的读写操作带来了计算的加速。

图2.10 FlashAttention和FlashAttention-2前向传播时在不同Warp间切分工作的方式

2.FlashAttention集成

FlashAttention相对于标准实现对注意力计算性能带来了数倍的提升。那么,如何在开源的大模型中集成FlashAttention以提升预训练的性能,从而支持更长上下文呢?下面这段代码展示了如何将Llama的注意力计算替换成FlashAttention的逻辑: xffoD2+fVLFMhkdCnNoQbDeLYNXCrVUmdK6NHedIttJX0IfAf7lCByksiwd7csg7

点击中间区域
呼出菜单
上一章
目录
下一章
×

打开