תחזית של כמה טוקנים (MTP) של Gemma 4 באמצעות Hugging Face Transformers

לצפייה ב-ai.google.dev הרצה ב-Google Colab הרצה ב-Kaggle פתיחה ב-Vertex AI צפייה במקור ב-GitHub

כדי לשפר את מהירות ההיקש של מודלי Gemma 4, השקנו סדרה חדשה של מודלים אוטורגרסיביים מסוג 'טיוטה' לצד המודלים העיקריים. במקום להסתמך רק על מודלי Gemma 4 העיקריים (שנקראים מודלי 'היעד'), מודל הטיוטה חוזה כמה טוקנים באופן אוטורגרסיבי בזמן שלוקח למודל היעד לעבד רק טוקן אחד. הטכניקה הזו נקראת גם פענוח ספקולטיבי.

אחרי שהמודל ליצירת טיוטה מנבא כמה טוקנים של טיוטה, מודל היעד צריך רק לאמת את הטוקנים המוצעים האלה. האימות מתבצע במקביל, ולכן תהליך ההסקה מהיר יותר. הוא מצמצם את מספר ההעברות קדימה שהמודל צריך לבצע עבור כל טוקן. מכיוון שהמנסח שלנו יוצר רצף של טוקנים לאימות, אנחנו מתייחסים אליו כאל ראש Multi-Token Prediction (MTP).

png

המודלים של Gemma 4 שפורסמו הם קטנים, והם כוללים כמה שיפורים שמטרתם לשפר את האיכות של הטוקנים שנוצרו ולזרז עוד יותר את ההסקה, כמו שימוש בהפעלות של מודל היעד ובמטמון KV כדי לקבל תחזיות טובות יותר.

השיפורים האלה מובילים להאצת פענוח משמעותית, תוך שמירה על איכות דומה. לכן, נקודות הבדיקה האלה מתאימות במיוחד לאפליקציות עם זמן אחזור נמוך ולאפליקציות במכשיר.

התקנת חבילות Python

מתקינים את הספריות של Hugging Face שנדרשות להרצת מודל Gemma 4 ומודל Gemma 4 Assistant.

# Install PyTorch & other libraries
pip install torch accelerate

# Install the transformers library
pip install transformers

טעינת המודלים

לכל מודל יעד (אחד מהמודלים העיקריים במודל Gemma 4) יש עוזר שמאיץ את ההסקה. לכן, תטענו שני מודלים:

  • Target (למשל, google/gemma-4-E2B-it): מודל היעד המלא של Gemma 4
  • Drafter (למשל, google/gemma-4-E2B-it-assistant): כלי קל משקל ליצירת טיוטות של MTP עם 4 שכבות, שמציע טוקנים מתאימים

שימו לב שהמנסח נקרא לעיתים קרובות עוזר, כי המודל עוזר למודל הגדול יותר לבחור אילו טוקנים לחזות.

משתמשים בספריות transformers כדי ליצור מופע של processor ושל model באמצעות המחלקות AutoProcessor ו-AutoModelForCausalLM, כמו בדוגמת הקוד הבאה:

TARGET_MODEL_ID = "google/gemma-4-E2B-it" # @param ["google/gemma-4-E2B-it","google/gemma-4-E4B-it", "google/gemma-4-31B-it", "google/gemma-4-26B-A4B-it"]
ASSISTANT_MODEL_ID = TARGET_MODEL_ID + "-assistant"
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

# Target Model
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
target_model = AutoModelForCausalLM.from_pretrained(
    TARGET_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Assistant Model (the drafter)
assistant_model = AutoModelForCausalLM.from_pretrained(
    ASSISTANT_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
[transformers] `torch_dtype` is deprecated! Use `dtype` instead!
Loading weights:   0%|          | 0/1951 [00:00<?, ?it/s]
Loading weights:   0%|          | 0/50 [00:00<?, ?it/s]

‫Gemma 4 עם Assistant

למזלנו, השימוש בעוזר ב-transformers הוא פשוט למדי, ונדרש רק להעביר את מודל העוזר לפונקציה model.generate:

# Process inputs with the `target_model`
messages = [
    {
        "role": "user",
        "content": "Explain the concepts of speculative decoding and MTP in 3 sentences."
    }
]
input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=input_text, return_tensors="pt").to(target_model.device)

# `assistant_model=assistant_model` is all you need to enable MTP!
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then quickly verified by a larger, more accurate model to produce a final, high-quality output much faster than decoding the large model alone. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly accelerate the inference speed of large language models while maintaining or improving output quality.

מאחורי הקלעים, התהליך הוא כזה:

  • הטיוטה מציעה N טוקנים שנוצרו באופן אוטומטי
  • מודל היעד מאמת את כל N הטוקנים במעבר קדימה אחד
  • מתקבלים טוקנים שנוסחו עם הסתברויות גבוהות
  • טוקנים שנכתבו עם הסתברויות נמוכות נדחים
  • מכיוון שמודל היעד מבצע העברה קדימה, הוא תמיד ייצור טוקן אחד בעצמו, לא משנה כמה טוקנים של טיוטה התקבלו או נדחו.

טוקנים של טיוטה

היוצר יכול ליצור כל כמות של אסימונים עבור מודל היעד כדי לאמת אותו. עם זאת, מודל היעד עדיין יכול לדחות טוקנים מסוימים. אם הוא מופיע, המערכת מתעלמת מכל האסימונים שמופיעים אחריו.

png

לכן חשוב להבין את ההשפעה של שימוש בערכים שונים למספר הטוקנים בטיוטה.

אסימונים נוספים של טיוטות

כשמנסחים הרבה טוקנים (לדוגמה, 15), יש סיכוי גבוה שלא כל הטוקנים יתקבלו. לכן, יש פוטנציאל גבוה יותר לבזבוז של משאבי מחשוב. לעומת זאת, יש לו נטייה להאיץ את ההסקה כששיעור הקבלה גבוה.

png

פחות טוקנים של טיוטות

כשמנסחים פחות טוקנים, שיעור הקבלה נוטה להיות גבוה יותר, כי טוקנים שקרובים יותר במיקום להנחיה הראשונית הם מדויקים יותר. עם זאת, מכיוון שרק כמה טוקנים נכתבים, השיפור במהירות שמתקבל ממודל מהיר יותר של כתיבת טקסטים הוא מוגבל.

png

למזלכם, אתם לא צריכים להתנסות עם הערכים הכי טובים לתרחיש השימוש שלכם ב-transformers, כי אתם יכולים להגדיר את num_assistant_tokens_schedule ל-heuristic (היוריסטיקה), וכך מספר הטוקנים שנוסחו יותאם אוטומטית בזמן הריצה:

  • כל הטוקנים התקבלו – כדאי להגדיל את מספר הטוקנים לטיוטה ב-2, כי הטיוטה די מדויקת להנחיה. הגדלת מספר הטוקנים שנוסחו עשויה להוביל להאצה אם הטוקנים האלה יתקבלו.
  • Any tokens rejected (אסימונים שנדחו) – אם נדחו אסימונים, צריך להקטין את מספר האסימונים לטיוטה ב-1. הפחתת מספר הטוקנים מבטיחה שלא יבוזבזו יותר מדי טיוטות אם מודל היעד ימשיך לדחות את רוב הטוקנים.

באופן דומה, אפשר לעדכן את מספר הטוקנים של הטיוטה על ידי עדכון num_assistant_tokens בטיוטה באופן הבא:

# Update how many draft tokens are generated at the start of inference
assistant_model.generation_config.num_assistant_tokens = 4

# Update how the number of draft tokens are updated ("heuristic" for a dynamic schedule and "constant" for a constant schedule)
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"

# Run with MTP
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then verified by a larger, more accurate model to quickly produce a high-quality output. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly speed up the inference process of large language models while maintaining or improving output quality.