MTP (Multi-Token Prediction) 训练指南
概述
MTP (Multi-Token Prediction) 是一种通过并行预测多个未来 token 来加速推理的技术。ROLL 框架支持 MTP 模型的训练,可用于 SFT(监督微调)和 RL(强化学习)场景。
投机采样原理
自回归生成的瓶颈
大语言模型的文本生成是自回归过程:每生成一个 token,都需要完整的前向传播。对于长序列生成(如数学推理),这成为主要的性能瓶颈。
传统自回归生成:
Step 1: 前向传播 → Token 1
Step 2: 前向传播 → Token 2
Step 3: 前向传播 → Token 3
...
每个 token 都需要一次完整的前向传播
投机采样的思想
投机采样(Speculative Decoding)通过"预测-验证"的方式打破这个瓶颈:
- Draft(草稿)阶段:使用一个小型模型快速生成 K 个候选 token
- Verify(验证)阶段:主模型一次前向传播并行验证这 K 个 token
- Accept/Reject:接受符合主模型概率分布的 token,拒绝不符合的
投机采样:
Draft: 小模型快速生成 [Token 1, Token 2, Token 3, Token 4]
Verify: 主模型一次前向传播验证所有候选
结果: 接受前 3 个,拒绝第 4 个
等效于:用 2 次前向传播(1次draft + 1次verify)生成了 3 个 token
为什么能加速?
关键洞察:主模型的前向传播可以并行计算多个位置的 logits。
传统方式下,生成 token 时只计算最后一个位置的 logits,其余位置的计算被浪费了。投机采样利用这一点,用一次主模型前向传播验证多个候选 token,从而提高计算效率。
加速效果取决于什么?
- 接受率:draft model 的输出分布与主模型越接近,接受率越高
- 投机步数:每次投机生成的候选 token 数量
- Draft model 效率:draft model 的推理速度
理想的 draft model 应该:
- 输出分布接近主模型(高接受率)
- 推理速度快(低 draft 开销)
- 参数量小(低内存开销)
什么是 MTP?
MTP (Multi-Token Prediction) 是一种高效的 draft model 实现。与使用独立小模型不同,MTP 与主模型共享权重,具有以下优势:
与普通 LM 的区别
- 普通 LM:用位置 t 的 hidden state 预测位置 t+1 的 token
- MTP:用位置 t 的 hidden state + 位置 t+1 的 token embedding 预测位置 t+2 的 token
普通 LM: H(t) → predict(t+1)
MTP: H(t) + E(t+1) → predict(t+2)
↑ ↑
hidden state embedding