|
|
在 Google Colab 中运行
|
|
|
在 GitHub 上查看源代码
|
为了提高 Gemma 4 模型的推理速度,我们发布了新的自回归“草稿”模型系列,与主要阵容一起发布。草稿模型不会仅仅依赖于主要的 Gemma 4 模型(称为“目标”模型),而是在目标模型处理一个 token 所需的时间内,以自回归方式预测多个 token。这种技术也称为推测性解码。
在草稿模型预测多个草稿 token 后,目标模型现在只需验证这些建议的草稿 token。验证是并行完成的,因此可以大幅加快推理速度。它减少了目标模型必须为每个 token 执行的前向传递次数。由于我们的草稿模型会生成一系列 token 进行验证,因此我们将其称为多 token 预测 (MTP) 头。

为 Gemma 4 系列发布的草稿模型很小,并引入了一些增强功能,以提高草稿 token 的质量并进一步加快推理速度,例如使用目标模型激活和 KV 缓存来获得更好的预测。
这些增强功能可显著加快解码速度,同时保证相似的质量,使这些检查点非常适合低延迟和设备端应用。
安装 Python 软件包
安装运行 Gemma 4 和 Gemma 4 助理模型所需的 Hugging Face 库。
# Install PyTorch & other librariespip install torch accelerate# Install the transformers librarypip install transformers
加载模型
对于每个目标模型(Gemma 4 模型中的主要模型之一),都有一个助理来帮助加快推理速度。因此,您将加载两个模型:
- 目标 (例如
google/gemma-4-E2B-it):完整的 Gemma 4 目标模型 - 草稿 (例如
google/gemma-4-E2B-it-assistant):轻量级 4 层 MTP 草稿模型,用于提出候选 token
请注意,由于该模型有助于较大的模型选择要预测的 token,因此 草稿模型 通常称为 助理模型。
使用 transformers 库,使用 AutoProcessor 和 AutoModelForCausalLM 类创建 processor 和 model 的实例,如以下代码示例所示:
TARGET_MODEL_ID = "google/gemma-4-E2B-it" # @param ["google/gemma-4-E2B-it","google/gemma-4-E4B-it", "google/gemma-4-31B-it", "google/gemma-4-26B-A4B-it"]
ASSISTANT_MODEL_ID = TARGET_MODEL_ID + "-assistant"
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
# Target Model
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
target_model = AutoModelForCausalLM.from_pretrained(
TARGET_MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
)
# Assistant Model (the drafter)
assistant_model = AutoModelForCausalLM.from_pretrained(
ASSISTANT_MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
)
[transformers] `torch_dtype` is deprecated! Use `dtype` instead! Loading weights: 0%| | 0/1951 [00:00<?, ?it/s] Loading weights: 0%| | 0/50 [00:00<?, ?it/s]
带有助理的 Gemma 4
幸运的是,在 transformers 中使用助理非常简单,您只需将助理模型传递给 model.generate 函数:
# Process inputs with the `target_model`
messages = [
{
"role": "user",
"content": "Explain the concepts of speculative decoding and MTP in 3 sentences."
}
]
input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=input_text, return_tensors="pt").to(target_model.device)
# `assistant_model=assistant_model` is all you need to enable MTP!
outputs = target_model.generate(
**inputs,
assistant_model=assistant_model,
max_new_tokens=256,
do_sample=False,
)
# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then quickly verified by a larger, more accurate model to produce a final, high-quality output much faster than decoding the large model alone. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly accelerate the inference speed of large language models while maintaining or improving output quality.
在底层,该过程如下所示:
- 草稿模型以自回归方式提出 N 个生成的 token
- 目标模型在一次 前向传递中验证所有 N 个 token
- 接受概率较高的草稿 token
- 拒绝概率较低的草稿 token
- 由于目标模型会执行前向传递,因此无论接受或拒绝多少草稿 token,它始终会自行生成 1 个 token
草稿 token
草稿模型可以生成任意数量的 token,供目标模型验证。不过,目标模型仍然可以选择拒绝某些 token。如果拒绝,则会忽略之后的所有 token。

因此,了解使用不同数量的草稿 token 时的权衡非常重要。
更多草稿 token
如果您起草了许多 token(例如 15 个),则很有可能并非所有 token 都会被接受。因此,浪费计算资源的潜力更大。相反,当接受率较高时,它确实有加快推理速度的趋势。

更少的草稿 token
如果您起草的 token 较少,则接受率往往会更高,因为位置更接近初始提示的 token 更准确。不过,由于只起草了少量 token,因此您从更快的草稿模型中获得的速度提升会减少。

幸运的是,您不必在 transformers 中针对您的用例尝试最佳值,因为您可以将 num_assistant_tokens_schedule 设置为“heuristic”,这会在运行时自动调整起草的 token 数量:
- 所有 token 都被接受 -- 将要起草的 token 数量增加 2,因为草稿模型对于提示非常准确。如果这些 token 也被接受,则增加起草的 token 数量可能会加快速度。
- 任何 token 被拒绝 -- 如果有任何 token 被拒绝,则将要起草的 token 数量减少 1。减少 token 数量可以确保,如果目标模型继续拒绝大多数 token,则不会浪费太多起草的 token。
同样,您可以通过更新草稿模型中的 num_assistant_tokens 来更新草稿 token 的数量,如下所示:
# Update how many draft tokens are generated at the start of inference
assistant_model.generation_config.num_assistant_tokens = 4
# Update how the number of draft tokens are updated ("heuristic" for a dynamic schedule and "constant" for a constant schedule)
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"
# Run with MTP
outputs = target_model.generate(
**inputs,
assistant_model=assistant_model,
max_new_tokens=256,
do_sample=False,
)
# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then verified by a larger, more accurate model to quickly produce a high-quality output. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly speed up the inference process of large language models while maintaining or improving output quality.
在 Google Colab 中运行
在 GitHub 上查看源代码