购买
下载掌阅APP,畅读海量书库
立即打开
畅读海量书库
扫码下载掌阅APP

1.1 Transformer的基础架构与原理

Transformer模型的核心在于其独特的多头注意力机制与网络稳定性设计,这一架构在处理自然语言任务中展现了卓越的建模能力。多头注意力机制通过对输入序列的并行计算,捕捉各词之间的复杂依赖关系,并通过查询、键和值的矩阵变换实现高效信息交互。同时,位置编码引入了序列中的位置信息,使得模型在没有循环结构的前提下具备捕捉顺序信息的能力。层归一化与残差连接进一步提升了网络的稳定性,确保深层结构的流畅传递和梯度的有效回传,为Transformer的深层学习奠定了基础。

本节主要介绍Transformer架构的多头注意力机制模块和位置编码设计的特点,并在此基础上讲解Transformer的核心原理。

1.1.1 多头注意力机制的核心计算

多头注意力机制是Transformer架构中的核心模块,其设计使模型能够关注输入序列中的不同位置,并并行地捕捉不同的语义关系。在这一机制中,每个输入词通过查询(Query)、键(Key)和值(Value)矩阵计算注意力分数。Transformer编码器结构与多头注意力机制分别如图1-1和图1-2所示。

图1-1 Transformer编码器架构图

图1-2 多头注意力机制示意图

多头注意力的计算可以看作多个注意力头的组合,每个头使用独立的权重矩阵对查询、键和值进行线性变换,从而在不同的语义维度上捕捉信息。接着,将各头的输出拼接起来,并通过线性变换生成最终输出。

以下代码将基于多头注意力机制的核心计算步骤实现完整示例,包括查询、键和值矩阵的生成,注意力分数的计算以及多头并行计算。

代码说明如下:

(1)定义多头注意力机制类MultiHeadAttention,其中包含查询、键和值的线性变换层和最终输出层。

(2)初始化时,将输入的嵌入维度按头数划分,每个头的维度为总嵌入维度除以头数。注意确保嵌入维度能够被头数整除。

(3)在forward方法中,首先根据头数重新调整查询、键和值的形状,使每个头拥有独立的嵌入表示。

(4)使用torch.einsum计算查询和键的点积,生成注意力分数矩阵,并应用softmax归一化得到每个位置的注意力权重。

(5)通过点积将注意力权重应用到值矩阵上,并将多个头的输出拼接起来,通过线性变换得到最终输出。

(6)在代码末尾初始化多头注意力层,并传入一个随机生成的输入数据进行测试,最终输出多头注意力机制的计算结果。

该实现将输出形状为(3, 10, 64)的多头注意力机制结果,表示经过8头并行注意力计算后的输出信息:

如果我们用生活中的例子来看待多头注意力机制,它就像几个朋友一起看电影,每个人关注的点不一样,有人观察剧情发展,有人注意演员表演,有人专注背景音乐。电影结束后,大家分享自己的感受,整合后对电影的理解比一个人单独看更深入。这种机制让模型能够同时捕捉到数据中的各种细节,理解更全面。

与单头注意力相比,多头注意力可以并行计算多个不同的注意力分布,能更高效地捕捉数据中的多种关系。如果只有一个头,模型的关注点可能过于单一,比如只看到主语和动词的关系,而忽略了修饰语的作用。多头注意力通过“分工合作”,确保模型能从不同角度分析数据,从而理解复杂结构的文本或其他输入。

多头注意力机制的关键在于“分头工作,最后汇总”。每个头独立计算注意力分布,通过关注不同的部分,捕捉数据的多样性,最后这些头的输出被整合,生成模型的最终输出。这种机制是Transformer成功的关键,正是它让模型能够灵活、高效地处理复杂的自然语言任务。

1.1.2 位置编码与网络稳定性的设计

在Transformer中,位置编码是一项关键设计,它为模型提供了序列顺序信息,弥补了无序的自注意力机制。位置编码通过将固定的位置信息添加到输入嵌入上,使模型能够在没有循环或卷积结构的情况下处理序列。常见的位置编码是采用正弦和余弦函数,通过不同频率的波形表示不同位置。

此设计为每个位置生成唯一的编码,确保模型能够学习到顺序依赖。此外,为了确保深层网络稳定,Transformer架构引入了层归一化和残差连接,保证梯度流动的稳定性和数据在层间的连贯性。Transformer完整架构如图1-3所示。

图1-3 Transformer完整架构图

通俗来说,Transformer就像一个聪明但没有记忆力的“听众”,虽然可以理解句子中每个单词的含义,但无法分辨这些单词在句子中的先后顺序。位置编码就像在每个单词上贴一个“编号”,告诉模型这个单词是第几个,以帮助它感知输入的结构。

位置编码的核心思想是为每个输入位置生成唯一的向量,这些向量与词嵌入(word embeddings)一起输入模型。固定位置编码公式如下:

其中, pos 是位置索引, i 是嵌入向量的维度索引, d 是嵌入向量的总维度。对于“猫吃老鼠”这样一句话,如果没有位置编码,Transformer可能认为“猫”和“老鼠”的关系是一样的,因为它只关注单词的内容,而忽略了顺序。有了位置编码,模型能够意识到“猫”在句子开头,“老鼠”在后面,从而正确理解句子含义。

以下代码将实现位置编码生成、层归一化以及残差连接。

代码说明如下:

(1)PositionalEncoding类生成位置编码矩阵,并将其添加到输入序列。位置编码通过不同频率的正弦和余弦函数,为模型引入位置信息。在forward方法中,位置编码被添加到输入序列上,使模型具备顺序意识。

(2)ResidualConnectionLayerNorm类实现了残差连接和层归一化,通过将子层输出与输入直接相加,并使用LayerNorm进行归一化,确保网络在深度增加的情况下保持稳定。forward方法对输入张量与子层输出执行残差连接。

(3)在代码末尾,首先初始化位置编码并将其应用于输入,随后初始化残差连接和层归一化模块,并对位置编码后的张量进行处理,最终输出结果。

代码运行结果如下: gs0vbU334Z6WZiyGy0IcdBge64kMFYl0zwSwzrMgKcxEhl58HOMX95W5U1Rm6MZi

点击中间区域
呼出菜单
上一章
目录
下一章
×