Hugging Face Transformers ile Gemma 4 Çoklu Jeton Tahmini (MTP)

ai.google.dev'de görüntüle Google Colab'de çalıştırma Kaggle'da çalıştırma Vertex AI'da aç Kaynağı GitHub'da görüntüleyin

Gemma 4 modellerinin çıkarım hızını artırmak için ana serinin yanı sıra yeni bir otoregresif “taslak oluşturucu” model serisi yayınlandı. Taslak modeli, yalnızca birincil Gemma 4 modellerine (hedef modeller olarak adlandırılır) güvenmek yerine, hedef modelin yalnızca bir jetonu işlemesi için gereken sürede birkaç jetonu otomatik regresif olarak tahmin eder. Bu teknik, spekülatif kod çözme olarak da bilinir.

Taslak oluşturucu, birden fazla taslak jetonu tahmin ettikten sonra hedef modelin artık yalnızca önerilen taslak jetonları doğrulaması gerekir. Doğrulama paralel olarak yapılır ve çıkarım önemli ölçüde hızlandırılır. Bu sayede, hedef modelin her jeton için yapması gereken iletme geçişlerinin sayısı azalır. Taslak oluşturucumuz, doğrulama için bir dizi jeton oluşturduğundan buna Çok Jetonlu Tahmin (MTP) başlığı diyoruz.

png

Gemma 4 ailesi için yayınlanan taslak modeller küçüktür ve taslak oluşturulan jetonların kalitesini artırmak, çıkarımı daha da hızlandırmak için çeşitli geliştirmeler sunar. Örneğin, daha iyi tahminler elde etmek için hedef model etkinleştirmelerini ve KV önbelleğini kullanır.

Bu geliştirmeler, benzer kaliteyi garanti ederken önemli ölçüde kod çözme hızlanması sağlar. Bu nedenle, bu kontrol noktaları düşük gecikmeli ve cihaz üzerinde uygulamalar için mükemmeldir.

Python paketlerini yükleme

Gemma 4 ve Gemma 4 asistan modelini çalıştırmak için gereken Hugging Face kitaplıklarını yükleyin.

# Install PyTorch & other libraries
pip install torch accelerate

# Install the transformers library
pip install transformers

Modelleri Yükleme

Her hedef model (Gemma 4 modelindeki ana modellerden biri) için çıkarım sürecini hızlandırmaya yardımcı olan bir asistan vardır. Bu nedenle, iki model yükleyeceksiniz:

  • Hedef (ör. google/gemma-4-E2B-it): Tam Gemma 4 hedef modeli
  • Drafter (ör. google/gemma-4-E2B-it-assistant): Aday jetonları öneren, 4 katmanlı hafif MTP taslağı

Model, hangi jetonların tahmin edileceğini seçme konusunda daha büyük modele yardımcı olduğundan taslak oluşturucunun genellikle asistan olarak adlandırıldığını unutmayın.

Aşağıdaki kod örneğinde gösterildiği gibi transformers kitaplıklarını kullanarak AutoProcessor ve AutoModelForCausalLM sınıflarını kullanarak processor ve model örneği oluşturun:

TARGET_MODEL_ID = "google/gemma-4-E2B-it" # @param ["google/gemma-4-E2B-it","google/gemma-4-E4B-it", "google/gemma-4-31B-it", "google/gemma-4-26B-A4B-it"]
ASSISTANT_MODEL_ID = TARGET_MODEL_ID + "-assistant"
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

# Target Model
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
target_model = AutoModelForCausalLM.from_pretrained(
    TARGET_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Assistant Model (the drafter)
assistant_model = AutoModelForCausalLM.from_pretrained(
    ASSISTANT_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
[transformers] `torch_dtype` is deprecated! Use `dtype` instead!
Loading weights:   0%|          | 0/1951 [00:00<?, ?it/s]
Loading weights:   0%|          | 0/50 [00:00<?, ?it/s]

Asistan ile Gemma 4

transformers içinde bir asistan kullanmak oldukça kolaydır ve asistan modelini model.generate işlevine iletmeniz gerekir:

# Process inputs with the `target_model`
messages = [
    {
        "role": "user",
        "content": "Explain the concepts of speculative decoding and MTP in 3 sentences."
    }
]
input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=input_text, return_tensors="pt").to(target_model.device)

# `assistant_model=assistant_model` is all you need to enable MTP!
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then quickly verified by a larger, more accurate model to produce a final, high-quality output much faster than decoding the large model alone. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly accelerate the inference speed of large language models while maintaining or improving output quality.

Perde arkasında süreç şu şekilde işler:

  • Taslağı oluşturan, N jetonun otomatik olarak oluşturulmasını önerir.
  • Hedef model, tek ileri geçişte tüm N jetonu doğrular.
  • Olasılığı yüksek olan taslak jetonlar kabul edilir.
  • Olasılığı düşük olan taslak jetonlar reddedilir.
  • Hedef model ileri geçiş yaptığından, kaç taslak jetonun kabul edildiğine veya reddedildiğine bakılmaksızın her zaman kendi başına 1 jeton oluşturur.

Taslak Jetonları

Taslak oluşturucu, hedef modelin doğrulanması için istediği sayıda jeton oluşturabilir. Ancak hedef model, belirli jetonları yine de reddedebilir. Bu durumda, bundan sonraki tüm jetonlar yoksayılır.

png

Bu nedenle, taslak oluşturulan jeton sayısı için çeşitli değerler kullanırken bu değerlerin ne gibi sonuçlar doğuracağını bilmek önemlidir.

Daha fazla taslak jetonu

Çok sayıda jeton taslağı oluşturduğunuzda (örneğin 15) tüm jetonların kabul edilmeme olasılığı yüksektir. Bu nedenle, bilgi işlem kaynaklarının boşa harcanma potansiyeli daha yüksektir. Bununla birlikte, kabul oranı yüksek olduğunda çıkarımı hızlandırma eğilimi vardır.

png

Daha az taslak jetonu

Daha az jeton tasarladığınızda, başlangıç istemine konum olarak daha yakın olan jetonlar daha doğru olduğundan kabul oranı daha yüksek olur. Ancak yalnızca birkaç jeton taslak olarak hazırlandığından, daha hızlı bir taslak oluşturma modelinden elde edeceğiniz hızlanma azalır.

png

Neyse ki transformers'da kullanım alanınız için en iyi değerlerle deneme yapmanız gerekmez. Çünkü num_assistant_tokens_schedule değerini "heuristic" olarak ayarlayabilirsiniz. Bu ayar, çalışma zamanında taslak oluşturulan jeton sayısını otomatik olarak uyarlar:

  • Tüm jetonlar kabul edildi: Taslak oluşturucu, istem konusunda oldukça doğru olduğundan taslak oluşturulacak jeton sayısını 2 artırın. Taslak oluşturulan jeton sayısının artırılması, bu jetonlar da kabul edilirse hızlanmaya neden olabilir.
  • Reddedilen jetonlar: Jeton reddedilirse taslak oluşturulacak jeton sayısını 1 azaltın. Hedef model jetonların çoğunu reddetmeye devam ederse jeton sayısını azaltmak, çok fazla taslağın boşa gitmesini önler.

Aynı şekilde, taslak oluşturucuda num_assistant_tokens değerini aşağıdaki gibi güncelleyerek taslak jeton sayısını da güncelleyebilirsiniz:

# Update how many draft tokens are generated at the start of inference
assistant_model.generation_config.num_assistant_tokens = 4

# Update how the number of draft tokens are updated ("heuristic" for a dynamic schedule and "constant" for a constant schedule)
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"

# Run with MTP
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then verified by a larger, more accurate model to quickly produce a high-quality output. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly speed up the inference process of large language models while maintaining or improving output quality.