一条序列 = 大量训练样本
假设你有一条长度为 6 的序列:[我, 喜欢, 学习, AI, Infra, 技术]
训练时,模型并不是只在最后算一次 loss。而是序列中的每个位置都同时作为一个训练样本。具体来说,这条序列被同时拆成了 5 个"输入→目标"对:
位置1: 输入 [我] → 预测 "喜欢" 位置2: 输入 [我, 喜欢] → 预测 "学习" 位置3: 输入 [我, 喜欢, 学习] → 预测 "AI" 位置4: 输入 [我, 喜欢, 学习, AI] → 预测 "Infra" 位置5: 输入 [我, 喜欢, 学习, AI, Infra] → 预测 "技术"也就是说,一条长度为 N 的序列,一次性提供了 N-1 个训练样本。这比"只拿最后一个 token 当目标"高效了 N-1 倍。
关键机制:Causal Mask(因果遮罩)
你可能会问:模型一次前向传播怎么处理这些不同长度的输入?答案是——输入始终是完整序列,但通过 attention mask 让每个位置只能看到它前面的 token。
Self-Attention 的计算是Attention(Q, K, V) = softmax(QK^T / √d) · V。在这个矩阵乘法中,QK^T会产生一个[seq_len, seq_len]的 score 矩阵。Causal Mask 就是把上三角部分设为-∞:
Score 矩阵(6×6): 我 喜欢 学习 AI Infra 技术 我 [ ✓ -∞ -∞ -∞ -∞ -∞ ] ← 位置1只能看到自己 喜欢 [ ✓ ✓ -∞ -∞ -∞ -∞ ] ← 位置2看到1,2 学习 [ ✓ ✓ ✓ -∞ -∞ -∞ ] ← 位置3看到1,2,3 AI [ ✓ ✓ ✓ ✓ -∞ -∞ ] Infra [ ✓ ✓ ✓ ✓ ✓ -∞ ] 技术 [ ✓ ✓ ✓ ✓ ✓ ✓ ]softmax 之后,-∞变成 0,所以位置 i 的注意力权重只分配给位置 1 到 i。这就保证了"不能偷看未来"。
前向传播一次,得到所有位置的预测
经过 Transformer 各层后,每个位置都会输出一个隐藏向量。最后通过 LM Head(线性层 + Softmax)得到每个位置的下一个 token 概率:
位置1的隐藏状态 → softmax → P(next="喜欢" | "我") 位置2的隐藏状态 → softmax → P(next="学习" | "我,喜欢") 位置3的隐藏状态 → softmax → P(next="AI" | "我,喜欢,学习") ...这些概率全是在一次前向传播中同时算出来的。这就是 GPU 并行计算的优势——整条序列的矩阵运算一次完成。
Loss 也是所有位置同时算
得到所有位置的预测概率后,loss 计算就是拿每个位置的预测和下一个真实 token 做交叉熵,然后求平均:
L = -1/(N-1) · Σᵢ log P(token_{i+1} | token_{1..i}) = -1/5 · [log P("喜欢"|"我") + log P("学习"|"我,喜欢") + log P("AI"|"我,喜欢,学习") + log P("Infra"|"我,喜欢,学习,AI") + log P("技术"|"我,喜欢,学习,AI,Infra")]这个标量 loss 反向传播时,梯度会从每个位置流回去,更新模型参数让所有位置的预测都变得更准。
代码层面其实很简单
PyTorch 里核心代码就几行:
# input_ids: [batch_size, seq_len],比如 [8, 2048]input_ids=batch["input_ids"]# 输入是前 N-1 个 token,标签是后 N-1 个 tokeninputs=input_ids[:,:-1]# [8, 2047]labels=input_ids[:,1:]# [8, 2047]# 一次前向,得到每个位置的 logitslogits=model(inputs)# [8, 2047, vocab_size]# 展平后算交叉熵loss=F.cross_entropy(logits.view(-1,vocab_size),# [8*2047, vocab_size]labels.view(-1)# [8*2047])注意model(inputs)内部会自动应用 causal mask。最终loss.backward()一次反向传播就更新了参数。
和推理的本质区别
训练时,因为所有位置的"正确答案"都已知(就是序列本身),所以可以一次前向传播算出所有位置的 loss,一次反向传播更新参数。效率极高。
推理时,你不知道下一个 token 是什么,所以只能逐个生成:先根据输入生成第 1 个 token,再把它拼回去生成第 2 个,以此类推。这就是 Decode 阶段效率低的原因——每个 token 都要单独做一次前向传播(虽然 batch 内的请求可以并行)。
所以训练是"已知答案,一次批改整张卷子",推理是"不知道答案,一道题一道题地做"。