|
|
在 Google Colab 中執行
|
|
|
在 GitHub 上查看來源
|
為提升 Gemma 4 模型的推論速度,我們在推出主要系列模型時,也發布了一系列新的自迴歸「草稿」模型。草稿模型不會只依賴主要的 Gemma 4 模型 (稱為「目標」模型),而是在目標模型處理一個權杖的時間內,自動迴歸預測多個權杖。這項技術也稱為「推測解碼」。
草稿人員預測多個草稿權杖後,目標模型現在只需要驗證這些建議的草稿權杖。驗證作業會平行執行,大幅加快推論速度。這項技術可減少目標模型為每個權杖執行的前向傳遞次數。由於草稿生成器會產生一連串的權杖以供驗證,因此我們將其稱為多權杖預測 (MTP) 標頭。

Gemma 4 系列的草稿模型體積小巧,並導入多項強化功能,可提升草稿權杖的品質,進一步加快推論速度,例如使用目標模型啟用和 KV 快取,取得更準確的預測結果。
這些強化功能可大幅加快解碼速度,同時確保品質相近,因此這些檢查點非常適合低延遲和裝置端應用程式。
安裝 Python 套件
安裝執行 Gemma 4 和 Gemma 4 助理模型所需的 Hugging Face 程式庫。
# Install PyTorch & other librariespip install torch accelerate# Install the transformers librarypip install transformers
載入模型
每個目標模型 (Gemma 4 模型中的主要模型之一) 都配備助理,可加快推論速度。因此,您會載入兩個模型:
- 目標 (例如
google/gemma-4-E2B-it):完整的 Gemma 4 目標模型 - 草稿撰寫者 (例如
google/gemma-4-E2B-it-assistant):輕量型 4 層 MTP 草稿撰寫者,可建議候選權杖
請注意,草稿模型通常稱為助理,因為這個模型會協助較大的模型選擇要預測的詞元。
使用 transformers 程式庫,透過 AutoProcessor 和 AutoModelForCausalLM 類別建立 processor 和 model 的執行個體,如以下程式碼範例所示:
TARGET_MODEL_ID = "google/gemma-4-E2B-it" # @param ["google/gemma-4-E2B-it","google/gemma-4-E4B-it", "google/gemma-4-31B-it", "google/gemma-4-26B-A4B-it"]
ASSISTANT_MODEL_ID = TARGET_MODEL_ID + "-assistant"
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
# Target Model
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
target_model = AutoModelForCausalLM.from_pretrained(
TARGET_MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
)
# Assistant Model (the drafter)
assistant_model = AutoModelForCausalLM.from_pretrained(
ASSISTANT_MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
)
[transformers] `torch_dtype` is deprecated! Use `dtype` instead! Loading weights: 0%| | 0/1951 [00:00<?, ?it/s] Loading weights: 0%| | 0/50 [00:00<?, ?it/s]
搭載 Google 助理的 Gemma 4
幸好,在 transformers 中使用助理相當簡單,只要將助理模型傳遞至 model.generate 函式即可:
# Process inputs with the `target_model`
messages = [
{
"role": "user",
"content": "Explain the concepts of speculative decoding and MTP in 3 sentences."
}
]
input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=input_text, return_tensors="pt").to(target_model.device)
# `assistant_model=assistant_model` is all you need to enable MTP!
outputs = target_model.generate(
**inputs,
assistant_model=assistant_model,
max_new_tokens=256,
do_sample=False,
)
# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then quickly verified by a larger, more accurate model to produce a final, high-quality output much faster than decoding the large model alone. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly accelerate the inference speed of large language models while maintaining or improving output quality.
在幕後,這個程序如下:
- 草稿撰寫者會提議以自迴歸方式生成的 N 個權杖
- 目標模型會在一次正向傳遞中驗證所有 N 個權杖
- 系統會接受高機率的草擬符記
- 系統會拒絕機率較低的草擬符記
- 由於目標模型會執行前向傳遞,因此無論接受或拒絕多少草稿符記,目標模型一律會自行產生 1 個符記
草稿權杖
草稿撰寫者可以為目標模型生成任意數量的權杖,以供驗證。不過,目標模型仍可選擇拒絕特定權杖。如果出現,系統會忽略之後的所有權杖。

因此,使用不同數量的草稿權杖時,請務必瞭解相關取捨。
更多草稿權杖
如果草擬大量權杖 (例如 15 個),則很有可能不會全部獲得核准。因此,浪費的運算資源可能更多。但如果接受率很高,則有加速推論的趨勢。

減少草稿符記
草擬的權杖越少,接受率就越高,因為位置越接近初始提示的權杖越準確。不過,由於系統只會草擬幾個權杖,因此即使使用速度較快的草擬模型,也無法大幅提升速度。

幸好,您不必在 transformers 中實驗找出最適合您用途的值,因為您可以將 num_assistant_tokens_schedule 設為「heuristic」,系統會在執行階段自動調整草擬權杖的數量:
- 接受所有權杖:由於草稿生成器對提示的解讀相當準確,因此將權杖數量增加 2 個。如果草擬的權杖數量增加,且這些權杖也獲得接受,速度可能會加快。
- 任何權杖遭拒:如有任何權杖遭拒,請將草稿權杖數量減少 1 個。減少權杖數量可避免浪費過多草稿,因為目標模型會繼續拒絕大部分權杖。
同樣地,您可以在 Drafter 中更新 num_assistant_tokens,藉此更新草稿權杖數量,如下所示:
# Update how many draft tokens are generated at the start of inference
assistant_model.generation_config.num_assistant_tokens = 4
# Update how the number of draft tokens are updated ("heuristic" for a dynamic schedule and "constant" for a constant schedule)
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"
# Run with MTP
outputs = target_model.generate(
**inputs,
assistant_model=assistant_model,
max_new_tokens=256,
do_sample=False,
)
# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then verified by a larger, more accurate model to quickly produce a high-quality output. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly speed up the inference process of large language models while maintaining or improving output quality.
在 Google Colab 中執行
在 GitHub 上查看來源