|
|
[Run in Google Colab]
|
|
|
[GitHub でソースを表示]
|
Gemma 4 モデルの推論速度を向上させるため、メインラインナップとともに、新しい一連の自己回帰型「ドラフター」モデルがリリースされました。ドラフト モデルは、プライマリ Gemma 4 モデル(「ターゲット」モデル)のみに依存するのではなく、ターゲット モデルが 1 つのトークンを処理する間に、複数のトークンを自己回帰的に予測します。この手法は、投機的デコードとも呼ばれます。
ドラフターが複数のドラフト トークンを予測すると、ターゲット モデルは提案されたドラフト トークンを検証するだけで済みます。検証は並行して行われるため、推論が大幅に高速化されます。これにより、ターゲット モデルが各トークンに対して行うフォワード パスの数が減ります。ドラフターは検証用のトークン シーケンスを生成するため、マルチトークン予測(MTP)ヘッドと呼ばれます。

Gemma 4 ファミリー向けにリリースされたドラフト モデルは小さく、ドラフト トークンの品質を向上させ、推論をさらに高速化するための機能強化がいくつか導入されています。たとえば、ターゲット モデルのアクティベーションと KV キャッシュを使用して、予測の精度を高めます。
これらの機能強化により、同様の品質を保証しながら、デコード速度が大幅に向上します。そのため、これらのチェックポイントは、低レイテンシのオンデバイス アプリケーションに最適です。
Python パッケージをインストールする
Gemma 4 モデルと Gemma 4 アシスタント モデルの実行に必要な Hugging Face ライブラリをインストールします。
# Install PyTorch & other librariespip install torch accelerate# Install the transformers librarypip install transformers
モデルを読み込む
各ターゲット モデル(Gemma 4 モデルのメインモデルの 1 つ)には、推論を高速化するアシスタントがあります。そのため、次の 2 つのモデルを読み込みます。
- ターゲット (例:
google/gemma-4-E2B-it): 完全な Gemma 4 ターゲット モデル - ドラフター (例:
google/gemma-4-E2B-it-assistant): 候補トークンを提案する軽量の 4 レイヤ MTP ドラフター
ドラフター は、予測するトークンの選択で大規模なモデルを支援するため、アシスタント と呼ばれることがよくあります。
transformers ライブラリを使用して、次のコード例に示すように、AutoProcessor クラスと AutoModelForCausalLM クラスを使用して processor と model のインスタンスを作成します。
TARGET_MODEL_ID = "google/gemma-4-E2B-it" # @param ["google/gemma-4-E2B-it","google/gemma-4-E4B-it", "google/gemma-4-31B-it", "google/gemma-4-26B-A4B-it"]
ASSISTANT_MODEL_ID = TARGET_MODEL_ID + "-assistant"
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
# Target Model
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
target_model = AutoModelForCausalLM.from_pretrained(
TARGET_MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
)
# Assistant Model (the drafter)
assistant_model = AutoModelForCausalLM.from_pretrained(
ASSISTANT_MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
)
[transformers] `torch_dtype` is deprecated! Use `dtype` instead! Loading weights: 0%| | 0/1951 [00:00<?, ?it/s] Loading weights: 0%| | 0/50 [00:00<?, ?it/s]
アシスタント付きの Gemma 4
transformers でアシスタントを使用するのは非常に簡単で、アシスタント モデルを model.generate 関数に渡す必要があります。
# Process inputs with the `target_model`
messages = [
{
"role": "user",
"content": "Explain the concepts of speculative decoding and MTP in 3 sentences."
}
]
input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=input_text, return_tensors="pt").to(target_model.device)
# `assistant_model=assistant_model` is all you need to enable MTP!
outputs = target_model.generate(
**inputs,
assistant_model=assistant_model,
max_new_tokens=256,
do_sample=False,
)
# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then quickly verified by a larger, more accurate model to produce a final, high-quality output much faster than decoding the large model alone. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly accelerate the inference speed of large language models while maintaining or improving output quality.
内部的には、次のようなプロセスになります。
- ドラフターは、自己回帰的に生成された N 個のトークンを提案します。
- ターゲット モデルは、1 回 のフォワード パスですべての N 個のトークンを検証します。
- 確率の高いドラフト トークンは受け入れられます。
- 確率の低いドラフト トークンは拒否されます。
- ターゲット モデルはフォワード パスを行うため、受け入れられたドラフト トークンの数に関係なく、常に 1 つのトークンを生成します。
ドラフト トークン
ドラフターは、ターゲット モデルが検証するトークンを任意の数だけ生成できます。ただし、ターゲット モデルは特定のトークンを拒否することもできます。拒否された場合、それ以降のトークンはすべて無視されます。

そのため、ドラフト トークンの数にさまざまな値を使用する場合のトレードオフを把握することが重要です。
ドラフト トークンが多い場合
多くのトークン(15 個など)をドラフトする場合、すべてのトークンが受け入れられるとは限りません。そのため、コンピューティング リソースが無駄になる可能性が高くなります。一方、受け入れ率が高い場合は、推論が高速化する傾向があります。

ドラフト トークンが少ない場合
ドラフト トークンが少ない場合、最初のプロンプトに近いトークンほど正確になるため、受け入れ率が高くなる傾向があります。ただし、ドラフトされるトークンが少ないため、ドラフター モデルの高速化による速度向上の効果は小さくなります。

幸いなことに、transformers でユースケースに最適な値を試す必要はありません。num_assistant_tokens_schedule を「heuristic」に設定すると、実行時にドラフト トークンの数が自動的に調整されます。
- すべてのトークンが受け入れられた場合 -- ドラフターはプロンプトに対して非常に正確であるため、ドラフトするトークンの数を 2 増やします。ドラフトされたトークンも受け入れられる場合は、ドラフトするトークンの数を増やすと速度が向上する可能性があります。
- トークンが拒否された場合 -- トークンが拒否された場合は、ドラフトするトークンの数を 1 減らします。トークンの数を減らすと、ターゲット モデルがほとんどのトークンを拒否し続けても、無駄になるドラフト トークンの数が少なくなります。
同様に、ドラフターの num_assistant_tokens を更新して、ドラフト トークンの数を更新できます。
# Update how many draft tokens are generated at the start of inference
assistant_model.generation_config.num_assistant_tokens = 4
# Update how the number of draft tokens are updated ("heuristic" for a dynamic schedule and "constant" for a constant schedule)
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"
# Run with MTP
outputs = target_model.generate(
**inputs,
assistant_model=assistant_model,
max_new_tokens=256,
do_sample=False,
)
# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then verified by a larger, more accurate model to quickly produce a high-quality output. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly speed up the inference process of large language models while maintaining or improving output quality.
[Run in Google Colab]
[GitHub でソースを表示]