پیش‌بینی چند توکنی (MTP) جما ۴ با استفاده از مبدل‌های چهره در آغوش گرفته

مشاهده در ai.google.dev در گوگل کولب اجرا کنید دویدن در کاگل باز کردن در Vertex AI مشاهده منبع در گیت‌هاب

برای بهبود سرعت استنتاج مدل‌های Gemma 4، سری جدیدی از مدل‌های «پیش‌نویس» خودهمبسته در کنار مدل‌های اصلی منتشر شده است. مدل پیش‌نویس به جای تکیه صرف بر مدل‌های اولیه Gemma 4 (که به عنوان مدل‌های «هدف» شناخته می‌شوند)، چندین توکن را به صورت خودهمبسته در زمانی که مدل هدف فقط یکی از آنها را پردازش می‌کند، پیش‌بینی می‌کند. این تکنیک همچنین به عنوان رمزگشایی حدسی شناخته می‌شود.

پس از اینکه طراح چندین توکن پیش‌نویس را پیش‌بینی کرد، مدل هدف اکنون فقط باید آن توکن‌های پیش‌نویس پیشنهادی را تأیید کند. تأیید به صورت موازی انجام می‌شود و در نتیجه سرعت استنتاج را به شدت افزایش می‌دهد. این کار تعداد پاس‌های رو به جلویی را که مدل هدف باید برای هر توکن انجام دهد، کاهش می‌دهد. از آنجا که طراح ما توالی‌ای از توکن‌ها را برای تأیید تولید می‌کند، ما به آن هد پیش‌بینی چند توکنی (MTP) می‌گوییم.

png

مدل‌های پیش‌نویس منتشر شده برای خانواده Gemma 4 کوچک هستند و چندین پیشرفت را برای بهبود کیفیت توکن‌های پیش‌نویس شده و افزایش سرعت استنتاج، مانند استفاده از فعال‌سازی‌های مدل هدف و KV-cache برای دستیابی به پیش‌بینی‌های بهتر، معرفی می‌کنند.

این پیشرفت‌ها منجر به افزایش قابل توجه سرعت رمزگشایی می‌شوند و در عین حال کیفیت مشابهی را تضمین می‌کنند، که این نقاط کنترل را برای برنامه‌های کاربردی با تأخیر کم و روی دستگاه ایده‌آل می‌کند.

نصب بسته‌های پایتون

کتابخانه‌های Hugging Face مورد نیاز برای اجرای Gemma 4 و مدل کمکی Gemma 4 را نصب کنید.

# Install PyTorch & other libraries
pip install torch accelerate

# Install the transformers library
pip install transformers

مدل‌ها را بارگذاری کنید

برای هر مدل هدف (یکی از مدل‌های اصلی در مدل Gemma 4)، یک دستیار وجود دارد که به سرعت بخشیدن به استنتاج کمک می‌کند. به این ترتیب، شما دو مدل را بارگذاری خواهید کرد:

  • هدف (مثلاً google/gemma-4-E2B-it ): مدل کامل هدف Gemma 4
  • طراح (مثلاً google/gemma-4-E2B-it-assistant ): طراح سبک وزن MTP 4 لایه که توکن‌های کاندید را پیشنهاد می‌دهد.

توجه داشته باشید که طراح مدل اغلب به عنوان دستیار شناخته می‌شود، زیرا این مدل به مدل بزرگتر در انتخاب توکن‌هایی که باید پیش‌بینی شوند، کمک می‌کند.

از کتابخانه‌های transformers برای ایجاد یک نمونه از یک processor و model با استفاده از کلاس‌های AutoProcessor و AutoModelForCausalLM ، همانطور که در مثال کد زیر نشان داده شده است، استفاده کنید:

TARGET_MODEL_ID = "google/gemma-4-E2B-it" # @param ["google/gemma-4-E2B-it","google/gemma-4-E4B-it", "google/gemma-4-31B-it", "google/gemma-4-26B-A4B-it"]
ASSISTANT_MODEL_ID = TARGET_MODEL_ID + "-assistant"
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

# Target Model
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
target_model = AutoModelForCausalLM.from_pretrained(
    TARGET_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Assistant Model (the drafter)
assistant_model = AutoModelForCausalLM.from_pretrained(
    ASSISTANT_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
[transformers] `torch_dtype` is deprecated! Use `dtype` instead!
Loading weights:   0%|          | 0/1951 [00:00<?, ?it/s]
Loading weights:   0%|          | 0/50 [00:00<?, ?it/s]

جما ۴ به همراه دستیار

خوشبختانه استفاده از یک دستیار در transformers کاملاً سرراست است و مستلزم آن است که مدل دستیار را به تابع model.generate ارسال کنید:

# Process inputs with the `target_model`
messages = [
    {
        "role": "user",
        "content": "Explain the concepts of speculative decoding and MTP in 3 sentences."
    }
]
input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=input_text, return_tensors="pt").to(target_model.device)

# `assistant_model=assistant_model` is all you need to enable MTP!
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then quickly verified by a larger, more accurate model to produce a final, high-quality output much faster than decoding the large model alone. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly accelerate the inference speed of large language models while maintaining or improving output quality.

در زیر کاپوت، روند به شرح زیر است:

  • طراح، N توکن تولید شده به صورت خودرگرسیونی را پیشنهاد می‌دهد.
  • مدل هدف، تمام N توکن را در یک مسیر رو به جلو تأیید می‌کند.
  • توکن‌های طراحی‌شده با احتمال برد بالا پذیرفته می‌شوند
  • توکن‌های طراحی‌شده با احتمال پایین رد می‌شوند
  • از آنجایی که مدل هدف یک انتقال رو به جلو انجام می‌دهد، صرف نظر از اینکه چند توکن پیش‌نویس پذیرفته یا رد شده‌اند، همیشه به تنهایی ۱ توکن تولید می‌کند.

توکن‌های پیش‌نویس

طراح می‌تواند هر تعداد توکن برای تأیید مدل هدف تولید کند. با این حال، مدل هدف همچنان می‌تواند توکن‌های خاصی را رد کند. در این صورت، تمام توکن‌های پس از آن نادیده گرفته می‌شوند.

png

به همین دلیل، دانستن مصالحه هنگام استفاده از مقادیر مختلف برای تعداد توکن‌های استخراج‌شده، مهم است.

توکن‌های پیش‌نویس بیشتر

وقتی تعداد زیادی توکن (مثلاً ۱۵ توکن) ایجاد می‌کنید، احتمال زیادی وجود دارد که همه توکن‌ها پذیرفته نشوند. به این ترتیب، پتانسیل بیشتری برای هدر رفتن محاسبات وجود دارد. در مقابل، وقتی نرخ پذیرش بالا باشد، تمایل به سرعت بخشیدن به استنتاج وجود دارد.

png

توکن‌های پیش‌نویس کمتر

وقتی تعداد کمتری توکن تهیه می‌کنید، نرخ پذیرش معمولاً بالاتر می‌رود، زیرا توکن‌هایی که به درخواست اولیه نزدیک‌تر هستند، دقیق‌ترند. با این حال، از آنجایی که فقط تعداد کمی توکن تهیه می‌شود، سرعتی که از یک مدل تهیه‌کننده سریع‌تر به دست می‌آورید، کاهش می‌یابد.

png

خوشبختانه، لازم نیست بهترین مقادیر را برای مورد استفاده خود در transformers آزمایش کنید، زیرا می‌توانید num_assistant_tokens_schedule را روی "heuristic" تنظیم کنید که به طور خودکار تعداد توکن‌های پیش‌نویس شده را در زمان اجرا تطبیق می‌دهد:

  • همه توکن‌ها پذیرفته می‌شوند -- تعداد توکن‌ها برای پیش‌نویس را ۲ واحد افزایش دهید، زیرا طراح برای این سوال کاملاً دقیق است. افزایش تعداد توکن‌های پیش‌نویس شده ممکن است منجر به افزایش سرعت شود، اگر آن توکن‌ها نیز پذیرفته شوند.
  • هر توکنی رد شود -- اگر هر توکنی رد شود، تعداد توکن‌های پیش‌نویس را ۱ واحد کاهش دهید. کاهش تعداد توکن‌ها باعث می‌شود که اگر مدل هدف همچنان بیشتر توکن‌ها را رد کند، تعداد زیادی از توکن‌های پیش‌نویس هدر نرود.

به همین ترتیب، می‌توانید تعداد توکن‌های پیش‌نویس را با به‌روزرسانی num_assistant_tokens در پیش‌نویس به صورت زیر به‌روزرسانی کنید:

# Update how many draft tokens are generated at the start of inference
assistant_model.generation_config.num_assistant_tokens = 4

# Update how the number of draft tokens are updated ("heuristic" for a dynamic schedule and "constant" for a constant schedule)
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"

# Run with MTP
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then verified by a larger, more accurate model to quickly produce a high-quality output. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly speed up the inference process of large language models while maintaining or improving output quality.