Predicción de varios tokens (MTP) de Gemma 4 con Transformers de Hugging Face

Ver en ai.google.dev Ejecutar en Google Colab Ejecutar en Kaggle Abrir en Vertex AI Ver el código fuente en GitHub

Para mejorar la velocidad de inferencia de los modelos de Gemma 4, se lanzó una nueva serie de modelos “borrador” autorregresivos junto con la línea principal. En lugar de depender únicamente de los modelos principales de Gemma 4 (denominados modelos “objetivo”), el modelo de borrador predice varios tokens de forma autorregresiva en el tiempo que le lleva al modelo objetivo procesar solo uno. Esta técnica también se conoce como decodificación especulativa.

Después de que el redactor predice varios tokens de borrador, el modelo objetivo solo tiene que verificar esos tokens de borrador sugeridos. La verificación se realiza en paralelo, lo que acelera drásticamente la inferencia. Reduce la cantidad de pases hacia adelante que el modelo objetivo debe realizar para cada token. Dado que nuestro redactor genera una secuencia de tokens para la verificación, nos referimos a él como el encabezado de predicción de varios tokens (MTP).

png

Los modelos de borrador lanzados para la familia Gemma 4 son pequeños y presentan varias mejoras para optimizar la calidad de los tokens borrador y acelerar aún más la inferencia, como el uso de las activaciones del modelo objetivo y el KV-cache para obtener mejores predicciones.

Estas mejoras generan aceleraciones significativas en la velocidad de decodificación y garantizan una calidad similar, lo que hace que estos puntos de control sean perfectos para aplicaciones de baja latencia e integradas en el dispositivo.

Instala paquetes de Python

Instala las bibliotecas de Hugging Face necesarias para ejecutar el modelo Gemma 4 y el asistente de Gemma 4.

# Install PyTorch & other libraries
pip install torch accelerate

# Install the transformers library
pip install transformers

Carga los modelos

Para cada modelo objetivo (uno de los modelos principales del modelo Gemma 4), hay un asistente que ayuda a acelerar la inferencia. Por lo tanto, cargarás dos modelos:

  • Objetivo (p. ej., google/gemma-4-E2B-it): Es el modelo objetivo completo de Gemma 4.
  • Drafter (p. ej., google/gemma-4-E2B-it-assistant): Es el borrador ligero de MTP de 4 capas que propone tokens candidatos.

Ten en cuenta que, a menudo, el borrador se conoce como el asistente, ya que el modelo ayuda al modelo más grande a elegir qué tokens predecir.

Usa las bibliotecas de transformers para crear una instancia de processor y model con las clases AutoProcessor y AutoModelForCausalLM, como se muestra en el siguiente ejemplo de código:

TARGET_MODEL_ID = "google/gemma-4-E2B-it" # @param ["google/gemma-4-E2B-it","google/gemma-4-E4B-it", "google/gemma-4-31B-it", "google/gemma-4-26B-A4B-it"]
ASSISTANT_MODEL_ID = TARGET_MODEL_ID + "-assistant"
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

# Target Model
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
target_model = AutoModelForCausalLM.from_pretrained(
    TARGET_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Assistant Model (the drafter)
assistant_model = AutoModelForCausalLM.from_pretrained(
    ASSISTANT_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
[transformers] `torch_dtype` is deprecated! Use `dtype` instead!
Loading weights:   0%|          | 0/1951 [00:00<?, ?it/s]
Loading weights:   0%|          | 0/50 [00:00<?, ?it/s]

Gemma 4 con el Asistente

Afortunadamente, usar un asistente en transformers es bastante sencillo y requiere que pases el modelo del asistente a la función model.generate:

# Process inputs with the `target_model`
messages = [
    {
        "role": "user",
        "content": "Explain the concepts of speculative decoding and MTP in 3 sentences."
    }
]
input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=input_text, return_tensors="pt").to(target_model.device)

# `assistant_model=assistant_model` is all you need to enable MTP!
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then quickly verified by a larger, more accurate model to produce a final, high-quality output much faster than decoding the large model alone. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly accelerate the inference speed of large language models while maintaining or improving output quality.

Detrás de escena, el proceso es el siguiente:

  • El borrador propone N tokens generados de forma autorregresiva.
  • El modelo objetivo verifica todos los N tokens en un pase hacia adelante.
  • Se aceptan los tokens redactados con probabilidades altas.
  • Se rechazan los tokens redactados con probabilidades bajas
  • Dado que el modelo objetivo realiza un pase hacia adelante, siempre generará 1 token por sí mismo, independientemente de cuántos tokens borrador se hayan aceptado o rechazado.

Tokens de borrador

El redactor puede generar cualquier cantidad de tokens para que el modelo objetivo los verifique. Sin embargo, el modelo de destino aún puede rechazar ciertos tokens. Cuando lo hace, se ignoran todos los tokens posteriores.

png

Por lo tanto, es importante conocer la compensación cuando se usan varios valores para la cantidad de tokens borrador.

Más tokens de borrador

Cuando creas muchos tokens (por ejemplo, 15), es muy probable que no se acepten todos. Por lo tanto, existe un mayor potencial de desperdicio de recursos de procesamiento. En cambio, sí tiende a acelerar la inferencia cuando la tasa de aceptación es alta.

png

Menos tokens de borrador

Cuando redactas menos tokens, la tasa de aceptación tiende a ser más alta, ya que los tokens que están más cerca en posición de la instrucción inicial son más precisos. Sin embargo, dado que solo se redactan unos pocos tokens, se reduce la aceleración que obtendrías de un modelo de redacción más rápido.

png

Afortunadamente, no tienes que experimentar con los mejores valores para tu caso de uso en transformers, ya que puedes establecer num_assistant_tokens_schedule en "heurístico", lo que adaptará automáticamente la cantidad de tokens creados en tiempo de ejecución:

  • Se aceptaron todos los tokens: Aumenta en 2 la cantidad de tokens para el borrador, ya que el redactor es bastante preciso para la instrucción. Aumentar la cantidad de tokens creados podría acelerar el proceso si también se aceptan esos tokens.
  • Se rechazó algún token: Si se rechaza algún token, reduce en 1 la cantidad de tokens para generar borradores. Reducir la cantidad de tokens hace que no se desperdicien demasiados borradores si el modelo objetivo sigue rechazando la mayoría de los tokens.

Del mismo modo, puedes actualizar la cantidad de tokens de borrador actualizando num_assistant_tokens en el redactor de la siguiente manera:

# Update how many draft tokens are generated at the start of inference
assistant_model.generation_config.num_assistant_tokens = 4

# Update how the number of draft tokens are updated ("heuristic" for a dynamic schedule and "constant" for a constant schedule)
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"

# Run with MTP
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then verified by a larger, more accurate model to quickly produce a high-quality output. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly speed up the inference process of large language models while maintaining or improving output quality.