|
|
Exécuter dans Google Colab
|
|
|
Afficher la source sur GitHub
|
Pour améliorer la vitesse d'inférence des modèles Gemma 4, une nouvelle série de modèles autorégressifs "brouillons" a été lancée en parallèle de la gamme principale. Au lieu de s'appuyer uniquement sur les modèles Gemma 4 principaux (appelés modèles "cibles"), le modèle brouillon prédit plusieurs jetons de manière autorégressive pendant le temps nécessaire au modèle cible pour n'en traiter qu'un seul. Cette technique est également appelée décodage spéculatif.
Une fois que le brouillon a prédit plusieurs jetons brouillons, le modèle cible n'a plus qu'à vérifier ces jetons brouillons suggérés. La vérification est effectuée en parallèle, ce qui accélère considérablement l'inférence. Cela réduit le nombre de passes avant que le modèle cible doit effectuer pour chaque jeton. Étant donné que notre brouillon génère une séquence de jetons à vérifier, nous l'appelons tête de prédiction multi-jetons (MTP).

Les modèles brouillons publiés pour la famille Gemma 4 sont petits et introduisent plusieurs améliorations pour améliorer la qualité des jetons brouillons et accélérer davantage l'inférence, comme l'utilisation des activations du modèle cible et du cache KV pour obtenir de meilleures prédictions.
Ces améliorations permettent d'accélérer considérablement le décodage tout en garantissant une qualité similaire, ce qui rend ces points de contrôle parfaits pour les applications à faible latence et sur appareil.
Installer des packages Python
Installez les bibliothèques Hugging Face requises pour exécuter le modèle d'assistant Gemma 4 et Gemma 4.
# Install PyTorch & other librariespip install torch accelerate# Install the transformers librarypip install transformers
Charger les modèles
Pour chaque modèle cible (l'un des principaux modèles du modèle Gemma 4), un assistant permet d'accélérer l'inférence. Vous chargerez donc deux modèles :
- Cible (par exemple,
google/gemma-4-E2B-it) : modèle cible Gemma 4 complet - Brouillon (par exemple,
google/gemma-4-E2B-it-assistant) : brouillon MTP léger à quatre couches qui propose des jetons candidats
Notez que le brouillon est souvent appelé assistant , car le modèle aide le modèle plus grand à choisir les jetons à prédire.
Utilisez les bibliothèques transformers pour créer une instance de processor et de model à l'aide des classes AutoProcessor et AutoModelForCausalLM, comme illustré dans l'exemple de code suivant :
TARGET_MODEL_ID = "google/gemma-4-E2B-it" # @param ["google/gemma-4-E2B-it","google/gemma-4-E4B-it", "google/gemma-4-31B-it", "google/gemma-4-26B-A4B-it"]
ASSISTANT_MODEL_ID = TARGET_MODEL_ID + "-assistant"
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
# Target Model
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
target_model = AutoModelForCausalLM.from_pretrained(
TARGET_MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
)
# Assistant Model (the drafter)
assistant_model = AutoModelForCausalLM.from_pretrained(
ASSISTANT_MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
)
[transformers] `torch_dtype` is deprecated! Use `dtype` instead! Loading weights: 0%| | 0/1951 [00:00<?, ?it/s] Loading weights: 0%| | 0/50 [00:00<?, ?it/s]
Gemma 4 avec l'assistant
L'utilisation d'un assistant dans transformers est heureusement assez simple et nécessite de transmettre le modèle d'assistant à la fonction model.generate :
# Process inputs with the `target_model`
messages = [
{
"role": "user",
"content": "Explain the concepts of speculative decoding and MTP in 3 sentences."
}
]
input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=input_text, return_tensors="pt").to(target_model.device)
# `assistant_model=assistant_model` is all you need to enable MTP!
outputs = target_model.generate(
**inputs,
assistant_model=assistant_model,
max_new_tokens=256,
do_sample=False,
)
# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then quickly verified by a larger, more accurate model to produce a final, high-quality output much faster than decoding the large model alone. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly accelerate the inference speed of large language models while maintaining or improving output quality.
En arrière-plan, le processus est le suivant :
- Le brouillon propose N jetons générés de manière autorégressive
- Le modèle cible vérifie tous les jetons N en une seule passe avant
- Les jetons brouillons avec des probabilités élevées sont acceptés
- Les jetons brouillons avec de faibles probabilités sont rejetés
- Étant donné que le modèle cible effectue une passe avant, il génère toujours un jeton par lui-même, quel que soit le nombre de jetons brouillons acceptés ou rejetés
Jetons brouillons
Le brouillon peut générer n'importe quel nombre de jetons que le modèle cible peut vérifier. Toutefois, le modèle cible peut toujours choisir de rejeter certains jetons. Dans ce cas, tous les jetons suivants sont ignorés.

Il est donc important de connaître le compromis lorsque vous utilisez différentes valeurs pour le nombre de jetons brouillons.
Plus de jetons brouillons
Lorsque vous rédigez de nombreux jetons (par exemple, 15), il est fort probable que tous les jetons ne soient pas acceptés. Par conséquent, le potentiel de calcul gaspillé est plus élevé. En revanche, cela a tendance à accélérer l'inférence lorsque le taux d'acceptation est élevé.

Moins de jetons brouillons
Lorsque vous rédigez moins de jetons, le taux d'acceptation a tendance à être plus élevé, car les jetons dont la position est plus proche de l'invite initiale sont plus précis. Toutefois, comme seuls quelques jetons sont rédigés, l'accélération que vous obtiendriez d'un modèle de brouillon plus rapide est réduite.

Heureusement, vous n'avez pas besoin d'expérimenter les meilleures valeurs pour votre cas d'utilisation dans transformers, car vous pouvez définir num_assistant_tokens_schedule sur "heuristic", ce qui adaptera automatiquement le nombre de jetons brouillons au moment de l'exécution :
- Tous les jetons acceptés : augmentez de deux le nombre de jetons à rédiger, car le brouillon est assez précis pour l'invite. L'augmentation du nombre de jetons rédigés peut entraîner une accélération si ces jetons sont également acceptés.
- Tous les jetons rejetés : si des jetons sont rejetés, réduisez de 1 le nombre de jetons à rédiger. La réduction du nombre de jetons permet de ne pas gaspiller trop de brouillons si le modèle cible continue de rejeter la plupart des jetons.
De même, vous pouvez mettre à jour le nombre de jetons brouillons en mettant à jour num_assistant_tokens dans le brouillon comme suit :
# Update how many draft tokens are generated at the start of inference
assistant_model.generation_config.num_assistant_tokens = 4
# Update how the number of draft tokens are updated ("heuristic" for a dynamic schedule and "constant" for a constant schedule)
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"
# Run with MTP
outputs = target_model.generate(
**inputs,
assistant_model=assistant_model,
max_new_tokens=256,
do_sample=False,
)
# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then verified by a larger, more accurate model to quickly produce a high-quality output. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly speed up the inference process of large language models while maintaining or improving output quality.
Exécuter dans Google Colab
Afficher la source sur GitHub