Previsão de vários tokens (MTP) do Gemma 4 usando Transformers do Hugging Face

Acessar em ai.google.dev Executar no Google Colab Executar no Kaggle Abrir no Vertex AI Conferir a origem no GitHub

Para melhorar a velocidade de inferência dos modelos Gemma 4, uma nova série de modelos autorregressivos de "rascunho" foi lançada junto com a linha principal. Em vez de depender apenas dos modelos principais do Gemma 4 (referidos como modelos "de destino"), o modelo de rascunho prevê vários tokens autorregressivamente no tempo que o modelo de destino leva para processar apenas um. Essa técnica também é conhecida como decodificação especulativa.

Depois que o rascunho prevê vários tokens, o modelo de destino só precisa verificar os tokens de rascunho sugeridos. A verificação é feita em paralelo, acelerando drasticamente a inferência. Ela reduz o número de passagens diretas que o modelo de destino precisa fazer para cada token. Como nosso rascunho gera uma sequência de tokens para verificação, nos referimos a ele como o cabeçalho de previsão de vários tokens (MTP, na sigla em inglês).

png

Os modelos de rascunho lançados para a família Gemma 4 são pequenos e introduzem várias melhorias para melhorar a qualidade dos tokens rascunhados e acelerar ainda mais a inferência, como usar as ativações do modelo de destino e o KV-cache para fazer melhores previsões.

Essas melhorias resultam em acelerações significativas na velocidade de decodificação, garantindo uma qualidade semelhante, o que torna esses pontos de verificação perfeitos para aplicativos de baixa latência e no dispositivo.

Instalar pacotes Python

Instale as bibliotecas do Hugging Face necessárias para executar o modelo de assistente do Gemma 4 e do Gemma 4.

# Install PyTorch & other libraries
pip install torch accelerate

# Install the transformers library
pip install transformers

Carregar os modelos

Para cada modelo de destino (um dos principais modelos do Gemma 4), há um assistente para ajudar a acelerar a inferência. Assim, você vai carregar dois modelos:

  • Destino (por exemplo, google/gemma-4-E2B-it): o modelo de destino completo do Gemma 4
  • Rascunho (por exemplo, google/gemma-4-E2B-it-assistant): o rascunho MTP leve de 4 camadas que propõe tokens candidatos

O rascunho é geralmente chamado de assistente porque o modelo ajuda o modelo maior a escolher quais tokens prever.

Use as bibliotecas transformers para criar uma instância de um processor e model usando as classes AutoProcessor e AutoModelForCausalLM, conforme mostrado no exemplo de código a seguir:

TARGET_MODEL_ID = "google/gemma-4-E2B-it" # @param ["google/gemma-4-E2B-it","google/gemma-4-E4B-it", "google/gemma-4-31B-it", "google/gemma-4-26B-A4B-it"]
ASSISTANT_MODEL_ID = TARGET_MODEL_ID + "-assistant"
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

# Target Model
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
target_model = AutoModelForCausalLM.from_pretrained(
    TARGET_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Assistant Model (the drafter)
assistant_model = AutoModelForCausalLM.from_pretrained(
    ASSISTANT_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
[transformers] `torch_dtype` is deprecated! Use `dtype` instead!
Loading weights:   0%|          | 0/1951 [00:00<?, ?it/s]
Loading weights:   0%|          | 0/50 [00:00<?, ?it/s]

Gemma 4 com o assistente

O uso de um assistente em transformers é bastante simples e exige que você transmita o modelo de assistente para a função model.generate:

# Process inputs with the `target_model`
messages = [
    {
        "role": "user",
        "content": "Explain the concepts of speculative decoding and MTP in 3 sentences."
    }
]
input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=input_text, return_tensors="pt").to(target_model.device)

# `assistant_model=assistant_model` is all you need to enable MTP!
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then quickly verified by a larger, more accurate model to produce a final, high-quality output much faster than decoding the large model alone. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly accelerate the inference speed of large language models while maintaining or improving output quality.

Nos bastidores, o processo é o seguinte:

  • O rascunho propõe N tokens gerados autorregressivamente
  • O modelo de destino verifica todos os N tokens em uma passagem direta
  • Tokens rascunhados com altas probabilidades são aceitos
  • Tokens rascunhados com baixa probabilidade são rejeitados
  • Como o modelo de destino faz uma passagem direta, ele sempre gera 1 token por conta própria, independentemente de quantos tokens rascunhados foram aceitos ou rejeitados

Tokens de rascunho

O rascunho pode gerar qualquer quantidade de tokens para o modelo de destino verificar. No entanto, o modelo de destino ainda pode optar por rejeitar determinados tokens. Quando isso acontece, todos os tokens depois disso são ignorados.

png

Portanto, é importante conhecer a compensação ao usar vários valores para o número de tokens rascunhados.

Mais tokens de rascunho

Quando você rascunha muitos tokens (por exemplo, 15), há uma grande chance de que nem todos sejam aceitos. Assim, há um potencial maior de computação desperdiçada. Por outro lado, ele tem uma tendência a acelerar a inferência quando a taxa de aceitação é alta.

png

Menos tokens de rascunho

Quando você rascunha menos tokens, a taxa de aceitação tende a ser maior, já que os tokens mais próximos da solicitação inicial são mais precisos. No entanto, como apenas alguns tokens são rascunhados, a aceleração que você receberia de um modelo de rascunho mais rápido é reduzida.

png

Felizmente, você não precisa fazer experimentos com os melhores valores para seu caso de uso em transformers, já que pode definir num_assistant_tokens_schedule como "heurístico", que vai adaptar automaticamente o número de tokens rascunhados no tempo de execução:

  • Todos os tokens aceitos : aumente o número de tokens a serem rascunhados em 2, já que o rascunho é bastante preciso para a solicitação. Aumentar o número de tokens rascunhados pode resultar em uma aceleração se esses tokens também forem aceitos.
  • Tokens rejeitados : se algum token for rejeitado, reduza a quantidade de tokens a serem rascunhados em 1. Reduzir o número de tokens faz com que não haja muitos rascunhos desperdiçados se o modelo de destino continuar rejeitando a maioria dos tokens.

Da mesma forma, é possível atualizar o número de tokens de rascunho atualizando num_assistant_tokens no rascunho, desta forma:

# Update how many draft tokens are generated at the start of inference
assistant_model.generation_config.num_assistant_tokens = 4

# Update how the number of draft tokens are updated ("heuristic" for a dynamic schedule and "constant" for a constant schedule)
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"

# Run with MTP
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then verified by a larger, more accurate model to quickly produce a high-quality output. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly speed up the inference process of large language models while maintaining or improving output quality.