Mamba2模型的实现

深入探索Mamba模型架构与应用 - 商品搜索 - 京东

 DeepSeek大模型高性能核心技术与多模态融合开发 - 商品搜索 - 京东

Mamba2在原有的Mamba模型的基础上融入了注意力机制,这一创新性的改进赋予了模型对远程信息的关注能力。通过引入注意力机制,Mamba2不仅保留了原Mamba模型对序列信息敏感的优点,还能够有效地捕获并处理长距离依赖关系。这一变革性的增强使得Mamba2在处理复杂任务时更加灵活和全面,大大提高了模型的性能和应用范围。

12.1.1  Mamba2核心组件SSD详解

结构化状态空间对偶性是在核心部分添加注意力机制,即将原有的SSM替换成带有注意力组件的新型架构,在具体实现上,我们可以参照GLM架构的注意力实现,首先完成其中的注意力机制,代码如下:

def segsum(x: Tensor, device: Device = None) -> Tensor:  
    """Stable segment sum calculation.  
  
    `exp(segsum(A))` 生成一个1-半可分矩阵,等同于一个标量SSM(Scalar SSM,可能是指某种特定的半可分矩阵)
    """  
      
    # 获取输入Tensor x的最后一个维度的大小,通常代表时间序列的长度  
    T = x.size(-1)  
      
    # 使用repeat函数扩展x的维度,使其在最后一个维度上增加一个与T相同大小的维度e  
    # 这实际上是为后续的矩阵操作做准备,生成一个二维的矩阵,其中每一行都是原始x的复制  
    x = repeat(x, "... d -> ... d e", e=T)
      
    # 创建一个下三角矩阵,其中对角线下方的元素为1(True),其余为0(False)
    # 这个矩阵将用作后续操作的掩码
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)  
      
    # 使用上面创建的掩码,将x中上三角部分的元素替换为0  
    x = x.masked_fill(~mask, 0)  
      
    # 沿着倒数第二个维度(即新扩展的维度e)计算累积和  
    x_segsum = torch.cumsum(x, dim=-2)  
      
    # 创建一个新的下三角矩阵,但这次包括对角线元素  
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)  
      
    # 使用新的掩码,将x_segsum中上三角部分的元素替换为负无穷大  
    # 这样做可能是为了在后续的计算中忽略这些值,或者使它们在softmax等操作中变得非常小  
    x_segsum = x_segsum.mask
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值