利用多令牌预测功能加速 Gemma 4

在 Gemma 4 中,多 token 预测 (MTP) 是一种用于实现高效推测性解码的特定架构。推测性解码是一种加速大语言模型推理的技术。它不是仅依赖于大型目标模型以自回归方式生成 token(一次生成一个 token,其中每个新 token 都依赖于之前的 token),而是使用较小、较快的“草稿模型”提前预测多个 token。然后,目标模型会并行验证这些草稿 token。如果目标模型拒绝草稿 token,它仍会为该位置生成正确的 token(确保该步骤不会浪费),并且草稿模型会从该新的正确 token 恢复预测。

Gemma 4 通过使用以下较小、较快的草稿模型扩展基础模型来实现 MTP。此草稿模型不是独立的,因为它与目标模型共享输入嵌入表,并直接基于其最后一层的激活。这可以显著加快解码速度,同时保证与标准自回归生成完全相同的质量,使这些检查点非常适合低延迟和设备端应用。

推测性解码的工作原理是起草多个 token,并在一次正向传递中验证它们。对于密集模型,每个 token 都使用相同的权重,因此验证多个草稿 token 只会增加极少的开销。混合专家 (MoE) 模型(例如 Gemma 4 26B A4B)的工作方式不同。每个 token 可能会激活不同的专家,因此验证草稿 token 可能需要从内存中加载额外的专家权重,从而抵消起草带来的收益。在较大的批次大小下,跨序列激活的专家通常会有更多重叠,从而提高加载权重的重用率。在批次大小为 1 时,这种重叠是有限的,这就是为什么 26B A4B 起草者可能无法在并行性不佳的硬件平台上提高速度。

MTP 增强功能

Gemma 4 对标准推测性解码流水线进行了一些增强,以提高草稿 token 的质量和效率:

  • 共享输入嵌入:草稿模型与目标模型共享输入嵌入 表。
  • 目标激活:草稿模型使用目标模型 最后一层的激活,将它们与 token 嵌入连接起来,然后将它们向下投影到起草者模型的维度。
  • 高效嵌入器:为避免跨整个词汇表进行预测的昂贵操作,该模型将相似的 token 分组到集群中。它首先识别最有可能的集群,然后将其最终计算限制为仅在这些选定集群中的 token(仅限 E2B 和 E4B)。