Ultralytics:解读TransformerBlock模块

在这里插入图片描述

前言

相关介绍

Ultralytics 简介

Ultralytics 基于多年的计算机视觉和人工智能基础研究,创建了最先进的 (SOTA) YOLO 模型。我们的模型不断更新性能和灵活性,快速、准确且易于使用。他们擅长对象检测、跟踪、实例分割、语义分割、图像分类和姿势估计任务。

前提条件

  • 熟悉Python、Pytorch

实验环境

Package                  Version
------------------------ ------------
Python                   3.11.8
absl-py                  2.4.0
accelerate               1.13.0
annotated-doc            0.0.4
anyio                    4.13.0
calflops                 0.3.2
certifi                  2026.4.22
charset-normalizer       3.4.7
click                    8.3.3
colorama                 0.4.6
contourpy                1.3.3
cycler                   0.12.1
filelock                 3.29.0
flatbuffers              25.12.19
fonttools                4.62.1
fsspec                   2026.4.0
grpcio                   1.80.0
h11                      0.16.0
hf-xet                   1.5.0
httpcore                 1.0.9
httpx                    0.28.1
huggingface_hub          1.14.0
idna                     3.15
Jinja2                   3.1.6
kiwisolver               1.5.0
Markdown                 3.10.2
markdown-it-py           4.2.0
MarkupSafe               3.0.3
matplotlib               3.10.9
mdurl                    0.1.2
ml_dtypes                0.5.0
mpmath                   1.3.0
networkx                 3.6.1
numpy                    1.26.4
nvidia-cublas-cu12       12.8.3.14
nvidia-cuda-cupti-cu12   12.8.57
nvidia-cuda-nvrtc-cu12   12.8.61
nvidia-cuda-runtime-cu12 12.8.57
nvidia-cudnn-cu12        9.7.1.26
nvidia-cufft-cu12        11.3.3.41
nvidia-cufile-cu12       1.13.0.11
nvidia-curand-cu12       10.3.9.55
nvidia-cusolver-cu12     11.7.2.55
nvidia-cusparse-cu12     12.5.7.53
nvidia-cusparselt-cu12   0.6.3
nvidia-nccl-cu12         2.26.2
nvidia-nvjitlink-cu12    12.8.61
nvidia-nvtx-cu12         12.8.55
onnx                     1.19.0
onnxruntime-gpu          1.26.0
onnxslim                 0.1.94
opencv-python            4.6.0.66
packaging                26.2
pillow                   12.2.0
pip                      24.0
polars                   1.40.1
polars-runtime-32        1.40.1
protobuf                 7.34.1
psutil                   7.2.2
pycocotools              2.0.11
Pygments                 2.20.0
pyparsing                3.3.2
python-dateutil          2.9.0.post0
PyYAML                   6.0.3
regex                    2026.5.9
requests                 2.34.1
rich                     15.0.0
safetensors              0.7.0
scipy                    1.16.0
setuptools               65.5.0
shellingham              1.5.4
six                      1.17.0
sympy                    1.14.0
tabulate                 0.10.0
tensorboard              2.20.0
tensorboard-data-server  0.7.2
tokenizers               0.22.2
torch                    2.7.1+cu128
torchaudio               2.7.1+cu128
torchvision              0.22.1+cu128
tqdm                     4.67.3
transformers             5.8.1
triton                   3.3.1
typer                    0.25.1
typing_extensions        4.15.0
ultralytics              8.4.58
ultralytics-thop         2.0.19
urllib3                  2.7.0
Werkzeug                 3.1.8

TransformerBlock(视觉Transformer模块)

TransformerBlock 是一个完整的 视觉 Transformer(ViT)模块,专为处理 2D 特征图 而设计。它集成了通道调整(可选)、可学习位置嵌入和多个 Transformer 层,使得模型能够直接对图像特征图进行全局自注意力建模。该实现参考了论文 “An Image is Worth 16x16 Words” 和 Ultralytics 的 YOLO 系列,常用于增强 CNN 特征的全局上下文表达能力。


代码实现

import cv2
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn

def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    """Standard convolution module with batch normalization and activation.

    Attributes:
        conv (nn.Conv2d): Convolutional layer.
        bn (nn.BatchNorm2d): Batch normalization layer.
        act (nn.Module): Activation function layer.
        default_act (nn.Module): Default activation function (SiLU).
    """

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given parameters.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels.
            k (int): Kernel size.
            s (int): Stride.
            p (int, optional): Padding.
            g (int): Groups.
            d (int): Dilation.
            act (bool | nn.Module): Activation function.
        """
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor.
        """
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply convolution and activation without batch normalization.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor.
        """
        return self.act(self.conv(x))

class TransformerLayer(nn.Module):
    """Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)."""

    def __init__(self, c: int, num_heads: int):
        """Initialize a self-attention mechanism using linear transformations and multi-head attention.

        Args:
            c (int): Input and output channel dimension.
            num_heads (int): Number of attention heads.
        """
        super().__init__()
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply a transformer block to the input x and return the output.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after transformer layer.
        """
        x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
        return self.fc2(self.fc1(x)) + x

class TransformerBlock(nn.Module):
    """Vision Transformer block based on https://arxiv.org/abs/2010.11929.

    This class implements a complete transformer block with optional convolution layer for channel adjustment, learnable
    position embedding, and multiple transformer layers.

    Attributes:
        conv (Conv, optional): Convolution layer if input and output channels differ.
        linear (nn.Linear): Learnable position embedding.
        tr (nn.Sequential): Sequential container of transformer layers.
        c2 (int): Output channel dimension.
    """

    def __init__(self, c1: int, c2: int, num_heads: int, num_layers: int):
        """Initialize a Transformer module with position embedding and specified number of heads and layers.

        Args:
            c1 (int): Input channel dimension.
            c2 (int): Output channel dimension.
            num_heads (int): Number of attention heads.
            num_layers (int): Number of transformer layers.
        """
        super().__init__()
        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)
        self.linear = nn.Linear(c2, c2)  # learnable position embedding
        self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
        self.c2 = c2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward propagate the input through the transformer block.

        Args:
            x (torch.Tensor): Input tensor with shape [b, c1, h, w].

        Returns:
            (torch.Tensor): Output tensor with shape [b, c2, h, w].
        """
        if self.conv is not None:
            x = self.conv(x)
        b, _, h, w = x.shape
        p = x.flatten(2).permute(2, 0, 1)
        return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, h, w)

功能

  • 通道调整:若输入通道 c1 不等于输出通道 c2,则先通过一个 Conv 层(含 BN 和激活)进行映射,否则跳过。
  • 可学习位置嵌入:对展平后的特征图(序列长度为 h*w,每个 token 维度为 c2)应用一个线性层(nn.Linear),生成与输入相加的位置编码,这使得模型能感知 token 在空间中的顺序。
  • 多头自注意力:通过多个 TransformerLayer 堆叠,对 token 序列进行全局自注意力建模,捕获长距离依赖。
  • 形状恢复:将输出序列重塑回 [b, c2, h, w] 的张量,保持与输入相同的空间尺寸。

初始化参数

参数类型说明
c1int输入特征图的通道数
c2int输出特征图的通道数
num_headsint每个 Transformer 层的注意力头数(必须能整除 c2
num_layersintTransformer 层的堆叠数量

c2 必须能被 num_heads 整除,否则 nn.MultiheadAttention 会报错。


前向方法

  • forward(x):输入 x(形状 [b, c1, h, w]),输出 [b, c2, h, w]

详细步骤

  1. c1 != c2,应用 self.conv 调整通道。
  2. 获取批次大小 b、通道 _、高 h、宽 w
  3. 将特征图展平为 [b, c2, h*w],然后转置为 [h*w, b, c2],得到 token 序列。
  4. 计算位置嵌入 self.linear(p)(形状相同),并加到 p 上。
  5. 通过 self.tr(多个 TransformerLayer)进行自注意力处理。
  6. 将输出转置回 [b, c2, h*w] 并重塑为 [b, c2, h, w]

使用示例

在这里插入图片描述

if __name__ == '__main__':
    # 设置参数
    batch, c1, h, w = 2, 32, 16, 16  # 输入特征图尺寸
    c2, num_heads, num_layers = 64, 4, 2  # 输出通道、头数、层数

    # 创建随机输入特征图
    x = torch.randn(batch, c1, h, w)

    # 创建 TransformerBlock
    block = TransformerBlock(c1=c1, c2=c2, num_heads=num_heads, num_layers=num_layers)

    # 前向传播
    with torch.no_grad():
        out = block(x)
    print("输入形状:", x.shape)   # [2, 32, 16, 16]
    print("输出形状:", out.shape) # [2, 64, 16, 16]

    # 使用真实图像演示(需要先转成特征图,这里直接使用灰度图模拟单通道特征)
    img_path = "cat_640x640.png"
    img_bgr = cv2.imread(img_path)
    if img_bgr is not None:
        # 缩放到 64x64,并转为灰度图模拟特征图(单通道)
        img_gray = cv2.cvtColor(cv2.resize(img_bgr, (64, 64)), cv2.COLOR_BGR2GRAY)
        # 转换为张量 [1, 1, 64, 64] 并归一化
        x_img = torch.from_numpy(img_gray).float().unsqueeze(0).unsqueeze(0) / 255.0
        # 创建适用于单通道的 TransformerBlock(c1=1, c2=8, 保证 c2 % num_heads == 0)
        block_img = TransformerBlock(c1=1, c2=8, num_heads=2, num_layers=1)
        with torch.no_grad():
            out_img = block_img(x_img)
        # 可视化输入和输出的第一个通道
        inp = x_img[0, 0].cpu().numpy()
        outp = out_img[0, 0].cpu().numpy()
        def norm(arr):
            return (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)
        plt.figure(figsize=(12, 5), constrained_layout=True)
        plt.subplot(1, 3, 1)
        plt.imshow(inp, cmap='gray')
        plt.title("Input")
        plt.axis("off")
        plt.subplot(1, 3, 2)
        plt.imshow(norm(outp), cmap='gray')
        plt.title("Output (Ch0)")
        plt.axis("off")
        plt.subplot(1, 3, 3)
        plt.imshow(np.abs(norm(inp) - norm(outp)), cmap='hot')
        plt.title("Difference")
        plt.axis("off")
        plt.savefig("transformer_block_demo.png", dpi=150)
        print("可视化已保存为 transformer_block_demo.png")

在这里插入图片描述

输出示例

输入形状: torch.Size([2, 32, 16, 16])
输出形状: torch.Size([2, 64, 16, 16])
可视化已保存为 transformer_block_demo.png

流程示意图

在这里插入图片描述


代码解读

  • __init__

    • self.conv:若 c1 != c2,用 Conv 层调整通道,保持空间尺寸不变(k=1 默认)。
    • self.linear:可学习的位置嵌入,对每个位置(h*w 个)赋予一个可学习的向量,形状为 (c2,)
    • self.tr:由 num_layersTransformerLayer 组成的序列。
  • forward

    • 通道调整后,x 形状为 (b, c2, h, w)
    • x.flatten(2) 得到 (b, c2, h*w)permute(2, 0, 1) 得到 (h*w, b, c2)
    • self.linear(p) 对每个 token 独立施加线性变换(相当于可学习的位置嵌入),与 p 相加。
    • 依次经过各 TransformerLayer,输出仍为 (h*w, b, c2)
    • permute(1, 2, 0) 得到 (b, c2, h*w)reshape(b, c2, h, w) 恢复空间结构。

注意事项

  1. 通道数必须能被 num_heads 整除:因为 TransformerLayer 内部 nn.MultiheadAttention 要求 embed_dim 能被 num_heads 整除。
  2. 序列长度h*w 是 token 数量,当特征图较大时(如 64×64 = 4096),自注意力的计算复杂度为 O(N²),内存和计算开销会急剧增加,建议在小分辨率特征图上使用。
  3. 位置嵌入可学习:与固定正弦编码不同,这里采用可学习的位置嵌入,但需要注意其参数量为 (h*w) * c2,对于大特征图可能较大。
  4. 无归一化和激活:内部 TransformerLayer 去除了 LayerNorm 和激活函数,可能影响训练稳定性,建议在堆叠时添加适当的正则化。
  5. 输入格式:输入必须是 4D 张量(b, c, h, w),支持任意空间尺寸(但需固定,因为位置嵌入的参数量与 h*w 绑定,一旦创建,输入尺寸必须固定)。

优缺点

优点
  1. 全局感受野:自注意力机制让每个位置能直接与其他所有位置交互,捕获长距离依赖,优于卷积的局部感受野。
  2. 即插即用:可无缝嵌入 CNN 骨干网络中,增强特征表示。
  3. 灵活性:通过 num_layersnum_heads 可调节模型容量。
  4. 位置感知:可学习的位置嵌入让模型能利用空间位置信息。
缺点
  1. 计算量大:自注意力的复杂度与序列长度平方成正比,高分辨率特征图下难以承受。
  2. 参数多:可学习位置嵌入随序列长度增长,增加模型体积。
  3. 训练可能不稳定:由于内部无 LayerNorm 和 Dropout,梯度可能爆炸或过拟合,需小心调参。
  4. 通道数限制c2 必须能被 num_heads 整除,限制了设计自由度。

在 YOLOv8 等模型中,TransformerBlock 通常作为 C2f 内部的增强模块,或用于构建 RT-DETR 的编码器。使用时建议在低分辨率特征图(如 16×16 或 8×8)上启用,并注意学习率预热和梯度裁剪。

参考文献

[1] https://docs.ultralytics.com/
[2] https://github.com/ultralytics/ultralytics.git

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

FriendshipT

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值