Gemma 4 – Multi-Token Prediction (MTP) mit Hugging Face Transformers

Auf ai.google.dev ansehen  In Google Colab ausführen In Kaggle ausführen In Vertex AI öffnen Quelle auf GitHub ansehen

Um die Inferenzgeschwindigkeit der Gemma 4-Modelle zu verbessern, wurde neben der Hauptreihe eine neue Reihe autoregressiver „Drafter“-Modelle veröffentlicht. Anstatt sich ausschließlich auf die primären Gemma 4-Modelle (die sogenannten „Zielmodelle“) zu verlassen, sagt das Draft-Modell mehrere Tokens autoregressiv in der Zeit voraus, die das Zielmodell für die Verarbeitung eines einzigen Tokens benötigt. Diese Technik wird auch als spekulatives Decodieren bezeichnet.

Nachdem das Drafter-Modell mehrere Vorschlags-Tokens vorhergesagt hat, muss das Zielmodell diese vorgeschlagenen Vorschlags-Tokens nur noch bestätigen. Die Überprüfung erfolgt parallel, wodurch die Inferenz erheblich beschleunigt wird. Dadurch wird die Anzahl der Forward-Passes reduziert, die das Zielmodell für jedes Token ausführen muss. Da unser Textersteller eine Folge von Tokens zur Überprüfung generiert, bezeichnen wir ihn als MTP-Head (Multi-Token Prediction).

png

Die für die Gemma 4-Familie veröffentlichten Modellentwürfe sind klein und bieten mehrere Verbesserungen, um die Qualität der erstellten Tokens zu verbessern und die Inferenz zu beschleunigen. Dazu werden beispielsweise die Aktivierungen des Zielmodells und der KV-Cache verwendet, um bessere Vorhersagen zu treffen.

Diese Verbesserungen führen zu einer erheblichen Beschleunigung der Dekodierung bei gleichbleibender Qualität. Daher eignen sich diese Checkpoints perfekt für Anwendungen mit geringer Latenz und On-Device-Anwendungen.

Python-Pakete installieren

Installieren Sie die Hugging Face-Bibliotheken, die zum Ausführen des Gemma 4- und des Gemma 4-Assistentenmodells erforderlich sind.

# Install PyTorch & other libraries
pip install torch accelerate

# Install the transformers library
pip install transformers

Modelle laden

Für jedes Zielmodell (eines der Hauptmodelle im Gemma 4-Modell) gibt es einen Assistenten, der die Inferenz beschleunigt. Sie laden also zwei Modelle:

  • Ziel (z.B. google/gemma-4-E2B-it): Das vollständige Gemma 4-Zielmodell
  • Drafter (z.B. google/gemma-4-E2B-it-assistant): Der einfache 4-Layer-MTP-Drafter, der Kandidatentokens vorschlägt

Der drafter wird oft als Assistent bezeichnet, da das Modell dem größeren Modell hilft, die vorherzusagenden Tokens auszuwählen.

Verwenden Sie die transformers-Bibliotheken, um eine Instanz von processor und model mit den Klassen AutoProcessor und AutoModelForCausalLM zu erstellen, wie im folgenden Codebeispiel gezeigt:

TARGET_MODEL_ID = "google/gemma-4-E2B-it" # @param ["google/gemma-4-E2B-it","google/gemma-4-E4B-it", "google/gemma-4-31B-it", "google/gemma-4-26B-A4B-it"]
ASSISTANT_MODEL_ID = TARGET_MODEL_ID + "-assistant"
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

# Target Model
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
target_model = AutoModelForCausalLM.from_pretrained(
    TARGET_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Assistant Model (the drafter)
assistant_model = AutoModelForCausalLM.from_pretrained(
    ASSISTANT_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
[transformers] `torch_dtype` is deprecated! Use `dtype` instead!
Loading weights:   0%|          | 0/1951 [00:00<?, ?it/s]
Loading weights:   0%|          | 0/50 [00:00<?, ?it/s]

Gemma 4 mit Assistant

Die Verwendung eines Assistenten in transformers ist zum Glück ganz einfach. Sie müssen das Assistentenmodell an die Funktion model.generate übergeben:

# Process inputs with the `target_model`
messages = [
    {
        "role": "user",
        "content": "Explain the concepts of speculative decoding and MTP in 3 sentences."
    }
]
input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=input_text, return_tensors="pt").to(target_model.device)

# `assistant_model=assistant_model` is all you need to enable MTP!
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then quickly verified by a larger, more accurate model to produce a final, high-quality output much faster than decoding the large model alone. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly accelerate the inference speed of large language models while maintaining or improving output quality.

So läuft der Prozess ab:

  • Der Textersteller schlägt N Tokens vor, die autoregressiv generiert werden.
  • Das Zielmodell überprüft alle N Tokens in einem Forward-Pass.
  • Entworfene Tokens mit hoher Wahrscheinlichkeit werden akzeptiert
  • Entworfene Tokens mit geringen Wahrscheinlichkeiten werden abgelehnt
  • Da das Zielmodell einen Forward-Pass durchführt, wird immer ein Token generiert, unabhängig davon, wie viele entworfenen Tokens akzeptiert oder abgelehnt wurden.

Entwurfstokens

Der Drafter kann eine beliebige Anzahl von Tokens für das Zielmodell generieren, die überprüft werden sollen. Das Zielmodell kann jedoch weiterhin bestimmte Tokens ablehnen. Wenn das der Fall ist, werden alle nachfolgenden Tokens ignoriert.

png

Daher ist es wichtig, die Auswirkungen der Verwendung verschiedener Werte für die Anzahl der erstellten Tokens zu kennen.

Mehr Draft-Tokens

Wenn Sie viele Tokens entwerfen (z. B. 15), ist es sehr wahrscheinlich, dass nicht alle akzeptiert werden. Daher besteht ein höheres Potenzial für verschwendete Rechenleistung. Wenn die Akzeptanzrate hoch ist, kann die Inferenz dadurch beschleunigt werden.

png

Weniger Entwurfstokens

Wenn Sie weniger Tokens erstellen, ist die Akzeptanzrate in der Regel höher, da Tokens, die näher am ursprünglichen Prompt liegen, genauer sind. Da jedoch nur wenige Tokens erstellt werden, ist die Geschwindigkeitssteigerung durch ein schnelleres Modell geringer.

png

Glücklicherweise müssen Sie nicht mit den besten Werten für Ihren Anwendungsfall in transformers experimentieren, da Sie num_assistant_tokens_schedule auf „heuristic“ setzen können. Dadurch wird die Anzahl der erstellten Tokens zur Laufzeit automatisch angepasst:

  • Alle akzeptierten Tokens: Erhöhe die Anzahl der zu erstellenden Tokens um 2, da der Textersteller sehr genau auf den Prompt reagiert. Wenn Sie die Anzahl der erstellten Tokens erhöhen, kann sich die Geschwindigkeit erhöhen, wenn diese Tokens auch akzeptiert werden.
  • Abgelehnte Tokens: Wenn Tokens abgelehnt werden, reduzieren Sie die Anzahl der zu erstellenden Tokens um 1. Durch die Reduzierung der Anzahl der Tokens wird verhindert, dass zu viele Entwürfe verschwendet werden, wenn das Zielmodell weiterhin die meisten Tokens ablehnt.

Ebenso können Sie die Anzahl der Draft-Tokens aktualisieren, indem Sie num_assistant_tokens im Drafter so aktualisieren:

# Update how many draft tokens are generated at the start of inference
assistant_model.generation_config.num_assistant_tokens = 4

# Update how the number of draft tokens are updated ("heuristic" for a dynamic schedule and "constant" for a constant schedule)
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"

# Run with MTP
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then verified by a larger, more accurate model to quickly produce a high-quality output. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly speed up the inference process of large language models while maintaining or improving output quality.