Vision Transformer 从零实现:理解 ViT 的核心机制
1. 引言
Vision Transformer (ViT) 在 2020 年由 Google 提出,首次证明纯 Transformer 架构在大规模数据预训练的前提下,可以在图像分类任务上达到甚至超越当时最好的 CNN。ViT 的核心思想是将图像切分为固定大小的 patch,每个 patch 视为一个 “token”,然后用标准 Transformer Encoder 处理。
本文目标: 用 PyTorch 从零实现 ViT,并在 CIFAR-10 上训练验证。
2. ViT 架构总览
输入图像 (224×224×3)
↓
Patch Embedding (patch 大小 16×16 像素 → 14×14 = 196 个 token,每个 768 维)
↓
[CLS] Token + Position Embedding
↓
Transformer Encoder × 12
├── LayerNorm → Multi-Head Self-Attention → Residual
└── LayerNorm → FFN (MLP) → Residual
↓
MLP Head → 分类输出
3. 核心实现
3.1 Patch Embedding
import torch
import torch.nn as nn
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each patch.

    A convolution whose kernel size and stride both equal ``patch_size``
    performs the patch split and the per-patch linear projection in a
    single operation.

    Args:
        img_size: Image side length used to derive ``num_patches``.
        patch_size: Side length of each square patch.
        in_channels: Number of channels of the input image.
        embed_dim: Size of the embedding produced for every patch.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size

        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2

        # kernel_size == stride, so every patch is seen by the kernel exactly once.
        self.projection = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Turn ``(B, C, H, W)`` images into ``(B, H/P * W/P, embed_dim)`` tokens."""
        patch_grid = self.projection(x)            # (B, embed_dim, H/P, W/P)
        tokens = patch_grid.flatten(start_dim=2)   # (B, embed_dim, H/P * W/P)
        return tokens.transpose(1, 2)              # (B, H/P * W/P, embed_dim)
3.2 Multi-Head Self-Attention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are produced by one fused linear layer, split
    into ``num_heads`` heads of size ``embed_dim // num_heads``, attended
    independently, then concatenated and passed through an output projection.

    Args:
        embed_dim: Size of the token embeddings (input and output).
        num_heads: Number of attention heads; must divide ``embed_dim``.
        dropout: Dropout probability applied to the attention weights and
            to the projected output.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim=768, num_heads=12, dropout=0.0):
        super().__init__()
        # Fail at construction time: with a non-divisible pair the integer
        # division below truncates and the reshape in forward() can never
        # succeed, but the error would only show up on the first forward pass.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(dropout)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(B, N, C)``; returns ``(B, N, C)``."""
        B, N, C = x.shape

        # Fused projection, then split into Q, K, V per head.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv.unbind(0)

        # Scaled dot-product attention.
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum of values, then merge the heads back into C channels.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # (B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
3.3 Transformer Encoder Block
class TransformerBlock(nn.Module):
    """One pre-norm Transformer encoder block.

    Two residual sub-layers are applied in sequence: LayerNorm followed by
    multi-head self-attention, then LayerNorm followed by a two-layer MLP
    with GELU activation.

    Args:
        embed_dim: Size of the token embeddings.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``.
        dropout: Dropout probability used in the attention and MLP sub-layers.
    """

    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)

        # Attention sub-layer.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)

        # Feed-forward sub-layer.
        self.norm2 = nn.LayerNorm(embed_dim)
        feed_forward = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Run both residual sub-layers on ``x`` and return the result."""
        # Normalisation happens before each sub-layer (pre-norm); the
        # un-normalised input is what gets added back as the residual.
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))
3.4 完整 ViT 模型
class VisionTransformer(nn.Module):
    """Vision Transformer image classifier.

    The image is converted to patch tokens, a learnable [CLS] token is
    prepended, learnable position embeddings are added, and the sequence is
    passed through ``depth`` encoder blocks. The normalised [CLS] output is
    fed to a linear classification head.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        num_classes: Number of output classes.
        embed_dim: Size of the token embeddings.
        depth: Number of encoder blocks.
        num_heads: Number of attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout probability used after the position embedding and
            inside every block.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=10,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.1,
    ):
        super().__init__()

        # Image -> patch tokens.
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        seq_len = self.patch_embed.num_patches + 1  # +1 for the [CLS] token

        # Learnable [CLS] token and position embedding.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        # Encoder stack.
        self.blocks = nn.ModuleList(
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        )

        # Classification head.
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialisation: embeddings first, then every Linear / LayerNorm.
        for embedding in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(embedding, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule; only Linear and LayerNorm are touched."""
        if isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
            return
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x):
        """Classify a batch of images ``(B, C, H, W)``; returns logits ``(B, num_classes)``."""
        batch_size = x.shape[0]

        patch_tokens = self.patch_embed(x)  # (B, num_patches, embed_dim)

        # Prepend one copy of the [CLS] token per sample.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)  # (B, num_patches+1, embed_dim)

        tokens = self.pos_drop(tokens + self.pos_embed)

        for block in self.blocks:
            tokens = block(tokens)

        # The prediction is read from the normalised [CLS] position.
        cls_output = self.norm(tokens)[:, 0]
        return self.head(cls_output)
4. CIFAR-10 训练
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Data preprocessing: CIFAR-10 images are resized to the 224x224 input the
# model below is configured for.
NUM_EPOCHS = 100  # single source of truth for the loop, the LR schedule and the log line

transform_train = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616)),
])

# The guard is required because the DataLoader uses worker processes: on
# platforms that start workers with "spawn", each worker re-imports this
# module, and unguarded top-level code would re-run the whole script.
if __name__ == "__main__":
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train
    )
    trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Small ViT variant for CIFAR-10. It is moved to the target device
    # before the optimizer is built so the optimizer is created from the
    # parameters as they live on that device.
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        num_classes=10,
        embed_dim=384,   # smaller embedding size
        depth=6,         # fewer encoder blocks
        num_heads=6,
        mlp_ratio=4.0,
        dropout=0.1,
    ).to(device)

    # Training configuration.
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
    criterion = nn.CrossEntropyLoss()

    # Training loop: one optimizer step per batch, one scheduler step per epoch.
    for epoch in range(NUM_EPOCHS):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running statistics over the training set for this epoch.
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        scheduler.step()

        acc = 100. * correct / total
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Loss: {total_loss/len(trainloader):.4f} | Acc: {acc:.2f}%")
5. 注意力可视化
import matplotlib.pyplot as plt
import numpy as np
def visualize_attention(model, image_tensor, device):
    """Save a heat map of the last block's [CLS]-to-patch attention.

    A forward hook on the last block's attention module recomputes the
    attention weights from that module's input. The weights are averaged
    over heads, the [CLS] row is reshaped to the patch grid and written to
    ``attention_map.png``.

    Args:
        model: Model exposing ``blocks[-1].attn`` with ``qkv``,
            ``num_heads`` and ``scale`` attributes.
        image_tensor: A single image without a batch dimension; a batch
            axis is added before the forward pass.
        device: Device the image is moved to before the forward pass. The
            model itself is not moved.

    Side effects:
        Puts ``model`` in eval mode and leaves it there, draws on the
        current pyplot figure, and writes ``attention_map.png``.
    """
    model.eval()
    attn_weights = []

    def hook_fn(module, inputs, output):
        # The attention module only returns its projected output, so the
        # weights are recomputed here from the same input tokens.
        tokens = inputs[0]
        B, N, _ = tokens.shape
        qkv = module.qkv(tokens).reshape(B, N, 3, module.num_heads, -1)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, _ = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * module.scale
        attn_weights.append(attn.softmax(dim=-1).detach().cpu())

    # Hook only the last attention layer. The handle is removed in a
    # `finally` so a failing forward pass cannot leave the hook attached.
    handle = model.blocks[-1].attn.register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            model(image_tensor.unsqueeze(0).to(device))
    finally:
        handle.remove()

    # Attention paid by the [CLS] token to every patch.
    attn = attn_weights[0][0]   # (heads, N, N)
    attn = attn.mean(dim=0)     # average over heads
    cls_attn = attn[0, 1:]      # row 0 is [CLS]; column 0 ([CLS] itself) is dropped

    # Assumes the patches form a square grid.
    grid_size = int(cls_attn.shape[0] ** 0.5)
    attn_map = cls_attn.reshape(grid_size, grid_size).numpy()

    plt.imshow(attn_map, cmap='hot', interpolation='nearest')
    plt.colorbar()
    plt.title("CLS Token Attention Map")
    plt.savefig("attention_map.png", dpi=150)
6. ViT 变体对比
| 模型 | 参数量 | 嵌入维度 | 深度 | 头数 | ImageNet Top-1 |
|---|---|---|---|---|---|
| ViT-Ti | 5.7M | 192 | 12 | 3 | 72.7% |
| ViT-S | 22M | 384 | 12 | 6 | 79.4% |
| ViT-B | 86M | 768 | 12 | 12 | 81.8% |
| ViT-L | 307M | 1024 | 24 | 16 | 85.2% |
| ViT-H | 632M | 1280 | 32 | 16 | 88.6% |
7. ViT vs CNN 深度对比
| 特性 | CNN | ViT |
|---|---|---|
| 归纳偏置 | 局部性、平移不变性 | 很弱(仅来自 patch 切分与位置编码,需大数据学习) |
| 全局建模 | 需要深层堆叠 | 第一层就能全局交互 |
| 计算复杂度 | O(n·k²·c) | O(n²·d)(n 为 token 数) |
| 小数据表现 | 优秀 | 较差(需预训练) |
| 大数据表现 | 饱和 | 持续提升 |
| 可解释性 | 较弱 | 可通过注意力图可视化 |
8. 总结
ViT 的核心创新在于将图像 patch 化后直接用 Transformer 处理,打破了 CNN 在视觉领域的垄断。关键理解:
- Patch Embedding = 切图 + 线性投影,等价于无重叠卷积
- CLS Token 是一个可学习的分类标记,聚合全局信息
- Position Embedding 对 ViT 至关重要(否则丧失空间信息)
- Pre-Norm(先 LayerNorm 再 Attention)比 Post-Norm 训练更稳定


411

被折叠的 条评论
为什么被折叠?



