VGGT代码阅读笔记

了解 Register Tokens

Register Tokens 相当于给模型提供了几个“空白笔记本”。模型可以将注意力机制中冗余的、全局的或非空间的信息存储在这些 Register 中,而不需要去“污染”代表实际像素或文字的 Token。 在 VGGT 中 一般为 4 个

大白话: 注意力交互过程中的 垃圾桶

Image 预处理

图像的 分辨率是 518, 这个是写死的, 如果图像不能满足 518 的分辨率,那么 默认会对于新图像进行 crop.
并且 基于 Transformer 的 Size 必须得被14 整除, 对于新的尺寸进行 ''crop"

  # Calculate height maintaining aspect ratio, divisible by 14
 new_height = round(height * (new_width / width) / 14) * 14
## 新的图像必须被 Resize 成 对应的 图像 size
 img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)

Forward 函数

输入: 仅仅是图像的 shape: [S, 3, H, W]
输出:pose, depth, world_points (包含world_points 的 confidence)
如果提供了 query point, 那么还包含 point 的 track 信息以及对应的 confidence:

 track (torch.Tensor): Point tracks with shape [B, S, N, 2]
 conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]

Aggregator 函数 (预处理)

aggregated_tokens_list, patch_start_idx = self.aggregator(images)

Aggregator 的核心思想是 alternating attention: 在 frame 和 global 两种 attention 之间交替

frame attention:每张图自己内部做 attention,本质上是单帧内 token 交互;
global attention:把所有帧的 token 串起来做 attention,本质上是跨帧交互

本质就是 先让每帧内部整理语义,再让多帧之间交换几何信息,然后反复交替。

  • 首先对于 Image 进行 Normalize 的归一化
 # Normalize images and reshape for patch embed [25,3,350,518]
images = (images - self._resnet_mean) / self._resnet_std
  • 将图像 通过 DINO 的编码器, 打成 patch, 得到含有 DINO 语义的 token
 # Reshape to [B*S, C, H, W] for patch embedding, 输出的是 提取 DINO feature 的  token
 images = images.view(B * S, C_in, H, W)
 patch_tokens = self.patch_embed(images)
 
 含有 DINO 语义的 token shape (25,925,1024), 25张图,每张图打成925个 token
 if isinstance(patch_tokens, dict):
      patch_tokens = patch_tokens["x_norm_patchtokens"]
  • 引入camera_token 和 4 个 register_token ,来作为最终的注意力网络的输入. 第一帧的 token 和其他帧不一样,因为将 第一帧作为参考帧。
  #  (1st vs others) 初始化 2个 camera_token, 第一个给第一帧, 第二个给其他帧
  camera_token = slice_expand_and_flatten(self.camera_token, B, S)

  初始化 2个 register_token, 第一个给第一帧, 第二个给其他帧
  register_token = slice_expand_and_flatten(self.register_token, B, S)

# camera_token shape(25,1,1024)
# register_tokenshape(25,4,1024)
# camera_token shape(25,925,1024)
## final token shape:  (25,930,1024)
tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)
  • 在 Transformer 里面有 Position Encoding, 所以我们需要对 patch 的位置进行编码

每个 token 对应的二维 patch 坐标 (y, x),不是像素坐标,而是 patch 网格坐标

## 对于 其他 patch,求取坐标 (0,0)-> (25,37)
if self.rope is not None:
   pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)

if self.patch_start_idx > 0:
   # 对于 camera_token 和 register_token 初始化其坐标为 (0,0)
   pos = pos + 1
   pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
   pos = torch.cat([pos_special, pos], dim=1)

Aggregator 函数 (Attention模块)

一共有 24 层, 每一层 都包含一个帧内的 frame attention 和 跨帧 的 global attention。Attention 的实现都是标准的 Attention 实现,

其中 frame attentionglobal attention 网络结构完全一样,都是 self-attention 的架构,区别只是 输入的 token 维度不同

  • 在经过 frame attention 对应的 token 的 shape: (25,930,1024), 930 个 token 交互信息,不涉及 跨帧(25)的信息交流

  • 在经过 global attention 对应的 token 的 shape: (1,25*930,1024), 23250 个 token 交互信息, 在跨帧(25)的信息进行交流

## 一共循环 L=24 层
for _ in range(self.aa_block_num):

	## self.aa_order = ['frame', 'global']
	## 先经过 frame attention, 在经过 global attention
  for attn_type in self.aa_order:
  		 ## frame_idx 表示调用第几个 frame attention (一共24个)
      if attn_type == "frame":
          tokens, frame_idx, frame_intermediates = self._process_frame_attention(
              tokens, B, S, P, C, frame_idx, pos=pos
          )
          
      ## global_idx 表示调用第几个 global attention (一共24个)
      elif attn_type == "global":
          tokens, global_idx, global_intermediates = self._process_global_attention(
              tokens, B, S, P, C, global_idx, pos=pos
          )
      else:
          raise ValueError(f"Unknown attention type: {attn_type}")

  for i in range(len(frame_intermediates)):
      # concat frame and global intermediates, [B x S x P x 2C]
      concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
      output_list.append(concat_inter)
  • 返回每一层经过 attention 之后的 frame token 和 global token
*       ## 返回24 层每一层的 frame attention 之后的 token 和 global attention之后的token
        ## start_idx=5; 因为由一个 camera_token 和 4个 register_token
        aggregated_tokens_list, patch_start_idx = self.aggregator(images)

Decoder 模块 (根据不同的 head 得到不同的属性)

Page4D 对于 global feature 可视化

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值