了解 Register Tokens
Register Tokens 相当于给模型提供了几个“空白笔记本”。模型可以将注意力机制中冗余的、全局的或非空间的信息存储在这些 Register 中,而不需要去“污染”代表实际像素或文字的 Token。 在 VGGT 中 一般为 4 个
大白话: 注意力交互过程中的 垃圾桶
Image 预处理
图像的 分辨率是 518, 这个是写死的, 如果图像不能满足 518 的分辨率,那么 默认会对于新图像进行 crop.
并且 基于 Transformer 的 Size 必须得被14 整除, 对于新的尺寸进行 ''crop"
# Calculate height maintaining aspect ratio, divisible by 14
new_height = round(height * (new_width / width) / 14) * 14
## 新的图像必须被 Resize 成 对应的 图像 size
img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
Forward 函数
输入: 仅仅是图像的 shape: [S, 3, H, W]
输出:pose, depth, world_points (包含world_points 的 confidence)
如果提供了 query point, 那么还包含 point 的 track 信息以及对应的 confidence:
track (torch.Tensor): Point tracks with shape [B, S, N, 2]
conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
Aggregator 函数 (预处理)
aggregated_tokens_list, patch_start_idx = self.aggregator(images)
Aggregator 的核心思想是 alternating attention: 在 frame 和 global 两种 attention 之间交替。
frame attention:每张图自己内部做 attention,本质上是单帧内 token 交互;
global attention:把所有帧的 token 串起来做 attention,本质上是跨帧交互
本质就是 先让每帧内部整理语义,再让多帧之间交换几何信息,然后反复交替。
- 首先对于 Image 进行 Normalize 的归一化
# Normalize images and reshape for patch embed [25,3,350,518]
images = (images - self._resnet_mean) / self._resnet_std
- 将图像 通过 DINO 的编码器, 打成 patch, 得到含有 DINO 语义的 token
# Reshape to [B*S, C, H, W] for patch embedding, 输出的是 提取 DINO feature 的 token
images = images.view(B * S, C_in, H, W)
patch_tokens = self.patch_embed(images)
含有 DINO 语义的 token shape (25,925,1024), 25张图,每张图打成925个 token
if isinstance(patch_tokens, dict):
patch_tokens = patch_tokens["x_norm_patchtokens"]
- 引入camera_token 和 4 个 register_token ,来作为最终的注意力网络的输入. 第一帧的 token 和其他帧不一样,因为将 第一帧作为参考帧。
# (1st vs others) 初始化 2个 camera_token, 第一个给第一帧, 第二个给其他帧
camera_token = slice_expand_and_flatten(self.camera_token, B, S)
初始化 2个 register_token, 第一个给第一帧, 第二个给其他帧
register_token = slice_expand_and_flatten(self.register_token, B, S)
# camera_token shape(25,1,1024)
# register_tokenshape(25,4,1024)
# camera_token shape(25,925,1024)
## final token shape: (25,930,1024)
tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)
- 在 Transformer 里面有 Position Encoding, 所以我们需要对 patch 的位置进行编码
每个 token 对应的二维 patch 坐标 (y, x),不是像素坐标,而是 patch 网格坐标
## 对于 其他 patch,求取坐标 (0,0)-> (25,37)
if self.rope is not None:
pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
if self.patch_start_idx > 0:
# 对于 camera_token 和 register_token 初始化其坐标为 (0,0)
pos = pos + 1
pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
pos = torch.cat([pos_special, pos], dim=1)
Aggregator 函数 (Attention模块)
一共有 24 层, 每一层 都包含一个帧内的 frame attention 和 跨帧 的 global attention。Attention 的实现都是标准的 Attention 实现,
其中 frame attention 和 global attention 网络结构完全一样,都是 self-attention 的架构,区别只是 输入的 token 维度不同
-
在经过 frame attention 对应的 token 的 shape: (25,930,1024), 930 个 token 交互信息,不涉及 跨帧(25)的信息交流
-
在经过 global attention 对应的 token 的 shape: (1,25*930,1024), 23250 个 token 交互信息, 在跨帧(25)的信息进行交流
## 一共循环 L=24 层
for _ in range(self.aa_block_num):
## self.aa_order = ['frame', 'global']
## 先经过 frame attention, 在经过 global attention
for attn_type in self.aa_order:
## frame_idx 表示调用第几个 frame attention (一共24个)
if attn_type == "frame":
tokens, frame_idx, frame_intermediates = self._process_frame_attention(
tokens, B, S, P, C, frame_idx, pos=pos
)
## global_idx 表示调用第几个 global attention (一共24个)
elif attn_type == "global":
tokens, global_idx, global_intermediates = self._process_global_attention(
tokens, B, S, P, C, global_idx, pos=pos
)
else:
raise ValueError(f"Unknown attention type: {attn_type}")
for i in range(len(frame_intermediates)):
# concat frame and global intermediates, [B x S x P x 2C]
concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
output_list.append(concat_inter)
- 返回每一层经过 attention 之后的 frame token 和 global token
* ## 返回24 层每一层的 frame attention 之后的 token 和 global attention之后的token
## start_idx=5; 因为由一个 camera_token 和 4个 register_token
aggregated_tokens_list, patch_start_idx = self.aggregator(images)

6904

被折叠的 条评论
为什么被折叠?



