购买
下载掌阅APP,畅读海量书库
立即打开
畅读海量书库
扫码下载掌阅APP

4.1 从状态空间模型SSM到结构化状态空间模型S4

我们的目标是实现对长序列的高效建模。为此,我们将基于状态空间模型(SSM)构建一个新的神经网络层。在本节结束时,我们将能够使用这一层来构建并运行模型。

前面提到,状态空间模型是动态系统的一种数学表达方式,能够很好地捕捉系统在时间上的演变规律。在神经网络中,我们可以利用状态空间模型来构建一个特殊的层,以实现对长序列数据的高效处理。这一层能够学习并模拟序列数据中的动态行为,从而提高模型的预测能力。

结构化状态空间模型(Structured State Space for Sequence,S4)创新性地融合了前面章节中阐述的状态空间模型、离散化技巧以及HiPPO初始化方法。通过这一综合,S4模型得以诞生,它不仅继承了状态空间模型的优点,还在高效处理长序列方面展现出卓越性能。特别是,S4模型中融入了HiPPO技术,使其在处理长距离依赖关系时表现尤为出色。结构化状态空间模型如图4-1所示。

本节将完成从基本的SSM到S4模型的讲解和组合,并通过PyTorch编程实现其中的部分内容,从而完成建模任务。

图4-1 结构化状态空间模型

4.1.1 从状态空间模型SSM开始(PyTorch具体实现)

在构建这个新的神经网络层之前,我们需要回忆一些关键概念和技术细节。首先,我们要明确状态空间模型的基本原理,包括状态方程和输出方程的定义及作用。状态方程描述了系统内部状态随时间的演变,而输出方程则将系统内部状态与外部观测联系起来。

接下来,我们将详细介绍如何在神经网络中实现这一层。具体来说,我们会探讨如何定义状态变量、状态转移矩阵以及输出矩阵,并通过训练来优化这些参数。此外,我们还会讨论如何将该层与其他神经网络层相结合,以构建一个完整的深度学习模型。

状态方程: h '( t )= Ah ( t )+ Bx ( t )。

输出方程: y ( t )= Ch' ( t )。

状态空间公式如上所示,它将一维输入信号 x ( t )经由 h ( t )变换后,投射到输出信号 y ( t )上。下面是一个随机生成状态空间系数SSM的函数,代码如下:

def random_SSM(N):
   A = torch.randn(size=(N,N))
   B = torch.randn(size=(N,1))
   C = torch.randn(size=(1,N))
   return A,B,C

接下来,我们需要对连续序列进行离散化处理。为了应用于离散的输入序列 x ( t )而非连续数据,状态空间模型必须通过步长 Δ 进行离散化,该步长代表输入的分辨率。从概念上讲,输入 x ( t )可以被视为对隐含的底层连续信号的采样。

通过这种离散化过程,我们能够将连续的状态空间模型转换为适用于离散序列数据的模型。这使得SSM在处理实际应用中的离散数据时更加灵活和实用。离散化后的SSM能够捕捉到序列数据中的动态变化,并通过学习状态转移矩阵和观测矩阵来模拟系统的行为。

为了将SSM应用于离散序列,我们需要选择合适的步长 Δ ,以确保模型能够准确地捕捉到输入数据的特征。步长 Δ 的选择应根据具体应用场景和数据特性进行,以保证模型的精度和效率之间的平衡。

4.1.2 连续信号转换为离散信号的PyTorch实现

状态空间模型中的离散化公式是微分方程描述,但计算机无法直接处理连续的信号,因此需要将系统离散化。

离散化的目标是将连续时间系统转换为离散时间系统,从而方便在计算机上进行数字模拟与分析。为实现这一目标,我们采用了一些关键的数学工具,这些工具能将连续时间系统转换成离散时间系统。通过这些公式,我们可以轻松地在计算机上模拟并分析连续动态系统的行为,如图4-2所示。

图4-2 连续信号转换为离散信号

为了应用于离散输入序列 x ( t )而非连续函数,状态空间模型必须通过步长 Δ step进行离散化,该步长代表了输入的分辨率。从概念上讲,输入 x ( t )可以被看作是对隐含的底层连续信号的采样。

为了离散化连续时间SSM,我们使用双线性变换方法。具体来说,从矩阵 A 的转换如下:

● 给定的离散化公式是将连续时间SSM的参数矩阵 A B C 转换为对应的离散时间SSM的参数矩阵

● 公式中的 Δ 表示采样周期,即离散化时的时间步长。

I 代表单位矩阵,其大小与 A 矩阵相同。

的公式采用了双线性变换,它是一种将连续时间系统的 s 平面映射到离散时间系统的 z 平面的方法。双线性变换能够保持系统的稳定性,并且在变换过程中引入了预畸变,以补偿离散化带来的误差。

的计算公式考虑了采样周期 Δ 对输入矩阵 B 的影响,确保在离散时间系统中能够正确地反映连续时间系统中的输入关系。

对于 ,它通常等于 C ,因为输出矩阵 C 描述了系统状态到输出的线性关系,这种关系在离散化过程中通常保持不变。

具体的PyTorch实现代码如下:

import torch

def discretize(A, B, C, step):
   I = torch.eye(A.size(0), dtype=A.dtype, device=A.device)  # 创建单位矩阵
   BL = torch.inverse(I - (step / 2.0) * A)   # 计算逆矩阵
   Ab = BL.mm(I + (step / 2.0) * A)           # 矩阵乘法
   Bb = (BL * step).mm(B)  # 注意这里先对BL的每个元素乘以step,再进行矩阵乘法
   return Ab, Bb, C

4.1.3 离散信号循环计算的PyTorch实现

该方程现在是一个从序列到序列的映射 x ( t )→ y ( t ),而不再是函数到函数的映射。此外,状态方程现在是 h ( t )的递推式,这使得离散状态空间模型(SSM)可以像循环神经网络(RNN)一样进行计算。具体来说, x ( t )∈ R N 可以被视为具有转移矩阵 的隐藏状态。

进一步来说,我们可以将离散SSM视为一种特殊的RNN,其中隐藏状态 h '( t )在每个时间步都根据输入数据 x ( t )和前一个隐藏状态 h ( t )进行更新。这种更新过程通过应用转移矩阵 和输入矩阵 来实现,同时还会考虑到可能的控制输入。然后,输出 y ( t )是通过将观察矩阵 应用到隐藏状态 h ( t )上获得的。

我们通过PyTorch中的for循环完成计算,代码如下:

def scan_SSM(Ab, Bb, Cb, x_tensor, hidden):
   h_t = hidden.unsqueeze(1)  # 确保h是一个列矩阵
   y_seq = []
   for x_t in x_tensor:
      x_t = x_t.unsqueeze(0)  # 确保u_k也是一个列矩阵
      h_t = torch.mm(Ab, h_t) + torch.mm(Bb, x_t)
      y_t = torch.mm(Cb, h_t)
      y_seq.append(y_t.squeeze(0))    # 移除额外的维度以添加到序列中
   return torch.stack(y_seq, dim=0)   # 将所有输出堆叠成一个序列

因此,离散SSM不仅能够捕捉动态系统的内部状态,还能根据这些状态生成相应的输出序列。这使得离散SSM在处理时间序列数据、控制系统建模以及预测等问题上非常有用。同时,由于其与RNN的相似性,我们可以利用深度学习框架来高效地实现和训练离散SSM,从而进一步拓展其应用范围。

4.1.4 状态空间模型SSM的PyTorch实现

为了整合前面介绍的状态空间模型的解释,作者定义了一个完整的状态空间模型(SSM)的PyTorch实现,并通过一系列函数来模拟这个模型的行为。

● random_SSM(N):用于生成SSM的三个主要矩阵:状态转移矩阵 A 、输入矩阵 B 以及输出矩阵 C 。这些矩阵的尺寸由参数 N (状态空间的维度)来决定,并且它们的元素都是从标准正态分布中随机抽取的。

● discretize(A, B, C, step):用于将连续的SSM转换为离散的SSM。这是通过将连续时间SSM的矩阵进行离散化处理来实现的,以便能够在离散的时间步长上模拟系统的行为。这个函数首先创建一个单位矩阵 I ,然后计算( I -(step/2) * A )的逆矩阵。接着,使用这个逆矩阵来计算离散化后的状态转移矩阵 A b 和输入矩阵 B b ,而输出矩阵 C 保持不变。

● scan_SSM(Ab, Bb, Cb, x_tensor, hidden):该函数模拟了离散SSM在给定的输入序列x_tensor下的行为。它初始化隐藏状态 h ( t )(即系统的内部状态),然后在每个时间步上根据离散化的SSM方程更新这个状态,并计算出对应的输出 y ( t )。这个过程是通过循环遍历输入序列的每个元素来完成的,最后将所有输出堆叠成一个序列并返回。

● run_SSM(A, B, C, x_tensor):该函数是整个模拟过程的主函数。它首先调用discretize函数来获取离散化的SSM矩阵,然后初始化隐藏状态为0向量,并调用scan_SSM函数来模拟SSM的行为。最后,返回模拟的输出序列。

完整代码如下:

在程序执行部分,代码首先设置了随机种子以确保结果的可重复性,然后生成了一个随机的SSM模型(通过调用random_SSM函数)。接着,创建了一个随机的输入序列x_tensor,这个序列有2个时间步和5个特征(这可能与多个输入信号或并行模拟有关)。最后,调用run_SSM函数来模拟SSM在这个输入序列下的行为,并打印出输出序列的形状。

总的来说,这段代码的作用是模拟一个随机生成的离散状态空间模型在给定的随机输入序列下的行为,并输出模拟的结果。需要注意的是,代码中的 N 应该在调用random_SSM之前被定义或作为参数传入,以避免在run_SSM函数中出现未定义变量的错误(在当前代码片段中, N 是在__main__部分定义的)。此外,输入张量x_tensor的形状可能需要根据具体的SSM模型和模拟需求进行调整。

另外,需要注意的是,代码中x_tensor的形状是(2, 5),这意味着有两个特征在5个时间步上的值,而不是两个并行的输入数据。在每个时间步上,系统接收一个5维的输入向量,并产生一个一维的输出向量(由于 C 是(1, N )维度的矩阵)。torch.stack的作用是将这些输出向量重新整合,形成一个完整的输出向量。

4.1.5 HiPPO算法初始化状态矩阵

前面我们完成了对SSM基本模型的介绍,在状态转移方面,我们采用以下实现方式:

def random_SSM(N):
   A = torch.randn(size=(N,N))
   B = torch.randn(size=(N,1))
   C = torch.randn(size=(1,N))
   return A,B,C

然而,仅通过随机初始化矩阵的方式往往可能无法有效地使模型学会捕捉长距离的依赖关系。相比之下,HiPPO算法专门设计用来弥补序列建模中远距离依赖的不足。HiPPO算法的核心思想在于它生成了一个隐藏状态,该状态能够记住序列前部的历史信息。

HiPPO算法指定了一类特殊的矩阵 A ,当这些矩阵被纳入SSM的方程中时,可以使状态 x ( t )能够记住输入 u ( t )的历史信息。这些特殊矩阵被称为HiPPO矩阵,它们具有特定的数学形式,可以有效地捕捉长期依赖关系。

修改后的代码如下:

import torch

def SSM(N):
   A = make_HiPPO(N)
   B = torch.randn(size=(N,1))
   C = torch.randn(size=(1,N))
   return A,B,C

def make_HiPPO(N):
   # 使用torch.arange 创建一个0~N-1的一维张量
   P = torch.sqrt(torch.tensor(1.0) + 2 * torch.arange(N, dtype=torch.float32))

   # 使用torch.unsqueeze增加维度以创建列向量和行向量
   P_col = P.unsqueeze(1)  # 列向量
   P_row = P.unsqueeze(0)  # 行向量

   # 使用广播机制计算A矩阵
   A = P_col * P_row

   # 使用torch.tril 获取下三角矩阵,并减去对角线上的值
   A = torch.tril(A) - torch.diag_embed(torch.arange(N, dtype=torch.float32))

   # 返回负矩阵
   return -A

读者可以自行将其代入4.1.4节的模型部分进行替换和比较。

4.1.6 基于S4架构的Mamba模型

通过对前面内容的讲解,可以看到SSM+HiPPO=S4。其主要工作是将HiPPO中的矩阵 A (称为HiPPO矩阵)转换为正规矩阵(正规矩阵可以分解为对角矩阵)和低秩矩阵的和,以此提高计算效率。S4通过这种分解将计算复杂度降低了,其中 N 是HiPPO矩阵的维度, L 是序列的长度。

在处理长度为16 000的序列的语音分类任务中,S4模型将专门设计的语音卷积神经网络(Speech CNN)的测试错误率降低了一半。相比之下,所有的循环神经网络和Transformer基线模型都无法学习,错误率均在70%以上。 yVkhVUBpt9EYnxZJGZwRrX3qBwWMMIrlveJMLzHCpobDDbiLEbGdkpB5Oe1kLDIA

点击中间区域
呼出菜单
上一章
目录
下一章
×