一条序列 = 大量训练样本
假设你有一条长度为 6 的序列:[我, 喜欢, 学习, AI, Infra, 技术]
训练时,模型并不是只在最后算一次 loss。而是序列中的每个位置都同时作为一个训练样本。具体来说,这条序列被同时拆成了 5 个"输入→目标"对:
位置1: 输入 [我] → 预测 "喜欢"
位置2: 输入 [我, 喜欢] → 预测 "学习"
位置3: 输入 [我, 喜欢, 学习] → 预测 "AI"
位置4: 输入 [我, 喜欢, 学习, AI] → 预测 "Infra"
位置5: 输入 [我, 喜欢, 学习, AI, Infra] → 预测 "技术"
也就是说,一条长度为 N 的序列,一次性提供了 N-1 个训练样本。这比"只拿最后一个 token 当目标"高效了 N-1 倍。
关键机制:Causal Mask(因果遮罩)
你可能会问:模型一次前向传播怎么处理这些不同长度的输入?答案是——输入始终是完整序列,但通过 attention mask 让每个位置只能看到它前面的 token。
Self-Attention 的计算是 Attention(Q, K, V) = softmax(QK^T / √d) · V。在这个矩阵乘法中,QK^T 会产生一个 [seq_len, seq_len] 的 score 矩阵。Causal Mask 就是把上三角部分设为 -∞:
Score 矩阵(6×6):
我 喜欢 学习 AI Infra 技术
我 [ ✓ -∞ -∞ -∞ -∞ -∞ ] ← 位置1只能看到自己
喜欢 [ ✓ ✓ -∞ -∞ -∞ -∞ ] ← 位置2看到1,2
学习 [ ✓ ✓ ✓ -∞ -∞ -∞ ] ← 位置3看到1,2,3
AI [ ✓ ✓ ✓ ✓ -∞ -∞ ]
Infra [ ✓ ✓ ✓ ✓ ✓ -∞ ]
技术 [ ✓ ✓ ✓ ✓ ✓ ✓ ]
softmax 之后,-∞ 变成 0,所以位置 i 的注意力权重只分配给位置 1 到 i。这就保证了"不能偷看未来"。
前向传播一次,得到所有位置的预测
经过 Transformer 各层后,每个位置都会输出一个隐藏向量。最后通过 LM Head(线性层 + Softmax)得到每个位置的下一个 token 概率:
位置1的隐藏状态 → softmax → P(next="喜欢" | "我")
位置2的隐藏状态 → softmax → P(next="学习" | "我,喜欢")
位置3的隐藏状态 → softmax → P(next="AI" | "我,喜欢,学习")
...
这些概率全是在一次前向传播中同时算出来的。这就是 GPU 并行计算的优势——整条序列的矩阵运算一次完成。
Loss 也是所有位置同时算
得到所有位置的预测概率后,loss 计算就是拿每个位置的预测和下一个真实 token 做交叉熵,然后求平均:
L = -1/(N-1) · Σᵢ log P(token_{i+1} | token_{1..i})
= -1/5 · [log P("喜欢"|"我")
+ log P("学习"|"我,喜欢")
+ log P("AI"|"我,喜欢,学习")
+ log P("Infra"|"我,喜欢,学习,AI")
+ log P("技术"|"我,喜欢,学习,AI,Infra")]
这个标量 loss 反向传播时,梯度会从每个位置流回去,更新模型参数让所有位置的预测都变得更准。
代码层面其实很简单
PyTorch 里核心代码就几行:
# input_ids: [batch_size, seq_len],比如 [8, 2048]
input_ids = batch["input_ids"]
# 输入是前 N-1 个 token,标签是后 N-1 个 token
inputs = input_ids[:, :-1] # [8, 2047]
labels = input_ids[:, 1:] # [8, 2047]
# 一次前向,得到每个位置的 logits
logits = model(inputs) # [8, 2047, vocab_size]
# 展平后算交叉熵
loss = F.cross_entropy(
logits.view(-1, vocab_size), # [8*2047, vocab_size]
labels.view(-1) # [8*2047]
)
注意 model(inputs) 内部会自动应用 causal mask。最终 loss.backward() 一次反向传播就更新了参数。
和推理的本质区别
训练时,因为所有位置的"正确答案"都已知(就是序列本身),所以可以一次前向传播算出所有位置的 loss,一次反向传播更新参数。效率极高。
推理时,你不知道下一个 token 是什么,所以只能逐个生成:先根据输入生成第 1 个 token,再把它拼回去生成第 2 个,以此类推。这就是 Decode 阶段效率低的原因——每个 token 都要单独做一次前向传播(虽然 batch 内的请求可以并行)。
所以训练是"已知答案,一次批改整张卷子",推理是"不知道答案,一道题一道题地做"。

3万+

被折叠的 条评论
为什么被折叠?



