Previsione multi-token (MTP) di Gemma 4 utilizzando Hugging Face Transformers

Visualizza su ai.google.dev Esegui in Google Colab Esegui in Kaggle Apri in Vertex AI Visualizza l'origine su GitHub

Per migliorare la velocità di inferenza dei modelli Gemma 4, insieme alla gamma principale è stata rilasciata una nuova serie di modelli autoregressivi "drafter". Anziché fare affidamento esclusivamente sui modelli Gemma 4 principali (denominati modelli "target"), il modello di bozza prevede diversi token autoregressivamente nel tempo necessario al modello target per elaborarne uno solo. Questa tecnica è nota anche come decodifica speculativa.

Dopo che il drafter ha previsto più token di bozza, il modello target deve solo verificare i token di bozza suggeriti. La verifica viene eseguita in parallelo, velocizzando notevolmente l'inferenza. Riduce il numero di passaggi in avanti che il modello target deve eseguire per ogni token. Poiché il nostro drafter genera una sequenza di token per la verifica, lo chiamiamo head di previsione multi-token (MTP).

png

I modelli di bozza rilasciati per la famiglia Gemma 4 sono piccoli e introducono diversi miglioramenti per migliorare la qualità dei token di bozza e velocizzare ulteriormente l'inferenza, ad esempio utilizzando le attivazioni del modello target e la cache KV per ottenere previsioni migliori.

Questi miglioramenti comportano un notevole aumento della velocità di decodifica, garantendo al contempo una qualità simile, il che rende questi checkpoint perfetti per le applicazioni a bassa latenza e on-device.

Installare i pacchetti Python

Installa le librerie Hugging Face necessarie per eseguire il modello Gemma 4 e il modello di assistenza Gemma 4.

# Install PyTorch & other libraries
pip install torch accelerate

# Install the transformers library
pip install transformers

Caricare i modelli

Per ogni modello target (uno dei modelli principali del modello Gemma 4), è disponibile un assistente che aiuta ad accelerare l'inferenza. Pertanto, caricherai due modelli:

  • Target (ad es. google/gemma-4-E2B-it): il modello target Gemma 4 completo
  • Drafter (ad es. google/gemma-4-E2B-it-assistant): il drafter MTP leggero a 4 livelli che propone token candidati

Tieni presente che il drafter viene spesso chiamato assistente perché il modello aiuta il modello più grande a scegliere i token da prevedere.

Utilizza le librerie transformers per creare un'istanza di un processor e di un model utilizzando le classi AutoProcessor e AutoModelForCausalLM, come mostrato nel seguente esempio di codice:

TARGET_MODEL_ID = "google/gemma-4-E2B-it" # @param ["google/gemma-4-E2B-it","google/gemma-4-E4B-it", "google/gemma-4-31B-it", "google/gemma-4-26B-A4B-it"]
ASSISTANT_MODEL_ID = TARGET_MODEL_ID + "-assistant"
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

# Target Model
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
target_model = AutoModelForCausalLM.from_pretrained(
    TARGET_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Assistant Model (the drafter)
assistant_model = AutoModelForCausalLM.from_pretrained(
    ASSISTANT_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
[transformers] `torch_dtype` is deprecated! Use `dtype` instead!
Loading weights:   0%|          | 0/1951 [00:00<?, ?it/s]
Loading weights:   0%|          | 0/50 [00:00<?, ?it/s]

Gemma 4 con l'assistente

Fortunatamente, l'utilizzo di un assistente in transformers è abbastanza semplice e richiede di passare il modello di assistente alla funzione model.generate:

# Process inputs with the `target_model`
messages = [
    {
        "role": "user",
        "content": "Explain the concepts of speculative decoding and MTP in 3 sentences."
    }
]
input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=input_text, return_tensors="pt").to(target_model.device)

# `assistant_model=assistant_model` is all you need to enable MTP!
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then quickly verified by a larger, more accurate model to produce a final, high-quality output much faster than decoding the large model alone. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly accelerate the inference speed of large language models while maintaining or improving output quality.

Dietro le quinte, la procedura è la seguente:

  • Il drafter propone N token generati autoregressivamente
  • Il modello target verifica tutti gli N token in un passaggio in avanti
  • I token di bozza con probabilità elevate vengono accettati
  • I token di bozza con probabilità basse vengono rifiutati
  • Poiché il modello target esegue un passaggio in avanti, genererà sempre 1 token da solo, indipendentemente dal numero di token di bozza accettati o rifiutati

Token di bozza

Il drafter può generare qualsiasi quantità di token da verificare per il modello target. Tuttavia, il modello target può comunque scegliere di rifiutare determinati token. In questo caso, tutti i token successivi vengono ignorati.

png

Pertanto, è importante conoscere il compromesso quando si utilizzano vari valori per il numero di token di bozza.

Più token di bozza

Quando crei molti token (ad esempio 15), è molto probabile che non tutti i token vengano accettati. Di conseguenza, esiste un potenziale maggiore di calcolo sprecato. Al contrario, tende ad accelerare l'inferenza quando il tasso di accettazione è elevato.

png

Meno token di bozza

Quando crei meno token, il tasso di accettazione tende a essere più alto perché i token più vicini alla richiesta iniziale sono più precisi. Tuttavia, poiché vengono creati solo pochi token, la velocità ottenuta da un modello di drafter più veloce viene ridotta.

png

Fortunatamente, non devi sperimentare i valori migliori per il tuo caso d'uso in transformers, perché puoi impostare num_assistant_tokens_schedule su "heuristic", che adatterà automaticamente il numero di token di bozza in fase di runtime:

  • Tutti i token accettati : aumenta di 2 il numero di token da creare, perché il drafter è abbastanza preciso per la richiesta. L'aumento del numero di token creati potrebbe comportare un aumento della velocità se anche questi token vengono accettati.
  • Qualsiasi token rifiutato : se vengono rifiutati dei token, riduci di 1 la quantità di token da creare. La riduzione del numero di token fa sì che non vengano sprecati troppi token di bozza se il modello target continua a rifiutare la maggior parte dei token.

Allo stesso modo, puoi aggiornare il numero di token di bozza aggiornando num_assistant_tokens nel drafter come segue:

# Update how many draft tokens are generated at the start of inference
assistant_model.generation_config.num_assistant_tokens = 4

# Update how the number of draft tokens are updated ("heuristic" for a dynamic schedule and "constant" for a constant schedule)
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"

# Run with MTP
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then verified by a larger, more accurate model to quickly produce a high-quality output. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly speed up the inference process of large language models while maintaining or improving output quality.