Prognozowanie wielu tokenów (MTP) w przypadku modelu Gemma 4 za pomocą biblioteki Hugging Face Transformers

Wyświetl na ai.google.dev Uruchom w Google Colab Uruchom w Kaggle Otwórz w Vertex AI Wyświetl źródło na GitHubie

Aby zwiększyć szybkość wnioskowania modeli Gemma 4, oprócz głównej serii udostępniliśmy nową serię autoregresywnych modeli „wersji roboczych”. Zamiast polegać wyłącznie na podstawowych modelach Gemma 4 (zwanych „modelami docelowymi”), model roboczy przewiduje autoregresywnie kilka tokenów w czasie, w którym model docelowy przetwarza tylko jeden token. Ta metoda jest też nazywana dekodowaniem spekulacyjnym.

Gdy model roboczy przewidzi wiele tokenów roboczych, model docelowy musi tylko zweryfikować te sugerowane tokeny. Weryfikacja jest przeprowadzana równolegle, co znacznie przyspiesza wnioskowanie. Zmniejsza to liczbę przejść do przodu, które model docelowy musi wykonać dla każdego tokena. Nasz generator tworzy sekwencję tokenów do weryfikacji, dlatego nazywamy go głowicą prognozowania wielu tokenów (Multi-Token Prediction, MTP).

png

Wersje robocze modeli z rodziny Gemma 4 są niewielkie i zawierają kilka ulepszeń, które zwiększają jakość generowanych tokenów i przyspieszają wnioskowanie. Wykorzystują np. aktywacje modelu docelowego i pamięć podręczną klucz-wartość, aby uzyskiwać lepsze prognozy.

Te ulepszenia znacznie przyspieszają dekodowanie przy zachowaniu podobnej jakości, dzięki czemu te punkty kontrolne idealnie nadają się do aplikacji o niskim opóźnieniu i aplikacji na urządzeniach.

Instalowanie pakietów Pythona

Zainstaluj biblioteki Hugging Face wymagane do uruchomienia modelu Gemma 4 i Gemma 4 Assistant.

# Install PyTorch & other libraries
pip install torch accelerate

# Install the transformers library
pip install transformers

Wczytywanie modeli

W przypadku każdego modelu docelowego (jednego z głównych modeli w modelu Gemma 4) dostępny jest asystent, który przyspiesza wnioskowanie. W związku z tym załadujesz 2 modele:

  • Model docelowy (np. google/gemma-4-E2B-it): pełny model docelowy Gemma 4
  • Drafter (np.google/gemma-4-E2B-it-assistant): lekki 4-warstwowy model MTP, który proponuje kandydatów na tokeny.

Pamiętaj, że model pomocniczy jest często nazywany asystentem, ponieważ pomaga większemu modelowi w wybieraniu tokenów do przewidywania.

Użyj bibliotek transformers, aby utworzyć instancję processormodel za pomocą klas AutoProcessorAutoModelForCausalLM, jak pokazano w tym przykładowym kodzie:

TARGET_MODEL_ID = "google/gemma-4-E2B-it" # @param ["google/gemma-4-E2B-it","google/gemma-4-E4B-it", "google/gemma-4-31B-it", "google/gemma-4-26B-A4B-it"]
ASSISTANT_MODEL_ID = TARGET_MODEL_ID + "-assistant"
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

# Target Model
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
target_model = AutoModelForCausalLM.from_pretrained(
    TARGET_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Assistant Model (the drafter)
assistant_model = AutoModelForCausalLM.from_pretrained(
    ASSISTANT_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
[transformers] `torch_dtype` is deprecated! Use `dtype` instead!
Loading weights:   0%|          | 0/1951 [00:00<?, ?it/s]
Loading weights:   0%|          | 0/50 [00:00<?, ?it/s]

Gemma 4 z Asystentem

Korzystanie z asystenta w transformers jest na szczęście dość proste i wymaga przekazania modelu asystenta do funkcji model.generate:

# Process inputs with the `target_model`
messages = [
    {
        "role": "user",
        "content": "Explain the concepts of speculative decoding and MTP in 3 sentences."
    }
]
input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=input_text, return_tensors="pt").to(target_model.device)

# `assistant_model=assistant_model` is all you need to enable MTP!
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then quickly verified by a larger, more accurate model to produce a final, high-quality output much faster than decoding the large model alone. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly accelerate the inference speed of large language models while maintaining or improving output quality.

Proces ten wygląda następująco:

  • Osoba przygotowująca propozycję proponuje N tokenów wygenerowanych autoregresywnie.
  • Model docelowy weryfikuje wszystkie N tokenów w jednym przejściu do przodu.
  • Zaakceptowane zostaną wygenerowane tokeny o wysokim prawdopodobieństwie.
  • Tokeny o niskim prawdopodobieństwie są odrzucane.
  • Model docelowy wykonuje przejście w przód, więc zawsze generuje 1 token samodzielnie, niezależnie od tego, ile tokenów roboczych zostało zaakceptowanych lub odrzuconych.

Tokeny wersji roboczej

Wersja robocza może wygenerować dowolną liczbę tokenów dla modelu docelowego, aby je zweryfikować. Model docelowy może jednak odrzucać niektóre tokeny. Gdy to nastąpi, wszystkie tokeny po nim zostaną zignorowane.

png

Dlatego ważne jest, aby znać kompromis przy używaniu różnych wartości liczby wygenerowanych tokenów.

Więcej tokenów wersji roboczej

Jeśli utworzysz wiele tokenów (np. 15), istnieje duże prawdopodobieństwo, że nie wszystkie zostaną zaakceptowane. Dlatego istnieje większe ryzyko zmarnowania zasobów obliczeniowych. Z drugiej strony, gdy współczynnik akceptacji jest wysoki, może przyspieszyć wnioskowanie.

png

Mniej tokenów wersji roboczej

Gdy tworzysz mniej tokenów, współczynnik akceptacji jest zwykle wyższy, ponieważ tokeny, które są bliżej początkowego promptu, są dokładniejsze. Jednak ponieważ generowanych jest tylko kilka tokenów, przyspieszenie, jakie można uzyskać dzięki szybszemu modelowi, jest mniejsze.

png

Na szczęście nie musisz eksperymentować z najlepszymi wartościami w swoim przypadku użycia w transformers, ponieważ możesz ustawić num_assistant_tokens_schedule na „heuristic”, co spowoduje automatyczne dostosowanie liczby wygenerowanych tokenów w czasie działania:

  • Wszystkie tokeny zaakceptowane – zwiększ liczbę tokenów do wersji roboczej o 2, ponieważ generator jest dość dokładny w przypadku tego promptu. Zwiększenie liczby wygenerowanych tokenów może przyspieszyć proces, jeśli zostaną one zaakceptowane.
  • Odrzucone tokeny – jeśli jakiekolwiek tokeny zostaną odrzucone, zmniejsz liczbę tokenów do utworzenia o 1. Zmniejszenie liczby tokenów sprawia, że nie marnuje się zbyt wielu wersji roboczych, jeśli model docelowy nadal odrzuca większość tokenów.

Podobnie możesz zaktualizować liczbę tokenów wersji roboczej, aktualizując num_assistant_tokens w narzędziu do tworzenia wersji roboczych w ten sposób:

# Update how many draft tokens are generated at the start of inference
assistant_model.generation_config.num_assistant_tokens = 4

# Update how the number of draft tokens are updated ("heuristic" for a dynamic schedule and "constant" for a constant schedule)
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"

# Run with MTP
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then verified by a larger, more accurate model to quickly produce a high-quality output. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly speed up the inference process of large language models while maintaining or improving output quality.