হাগিং ফেস ট্রান্সফর্মার ব্যবহার করে জেমা ৪ মাল্টি-টোকেন প্রেডিকশন (এমটিপি)

ai.google.dev-এ দেখুন গুগল কোলাবে চালান Kaggle-এ চালান ভার্টেক্স এআই-তে খুলুন গিটহাবে উৎস দেখুন

জেমা ৪ মডেলগুলোর ইনফারেন্স গতি উন্নত করার জন্য, মূল মডেলগুলোর পাশাপাশি অটোরিগ্রেসিভ “ড্রাফটার” মডেলের একটি নতুন সিরিজ প্রকাশ করা হয়েছে। শুধুমাত্র প্রাথমিক জেমা ৪ মডেলগুলোর (যাকে “টার্গেট” মডেল বলা হয়) উপর নির্ভর না করে, ড্রাফট মডেলটি টার্গেট মডেলের একটি টোকেন প্রসেস করার সময়েই অটোরিগ্রেসিভভাবে বেশ কয়েকটি টোকেন প্রেডিক্ট করে। এই কৌশলটি স্পেকুলেটিভ ডিকোডিং নামেও পরিচিত।

ড্রাফটার একাধিক ড্রাফট টোকেন অনুমান করার পর, টার্গেট মডেলকে এখন শুধু সেই প্রস্তাবিত ড্রাফট টোকেনগুলো যাচাই করতে হয়। এই যাচাইকরণ প্রক্রিয়াটি সমান্তরালভাবে সম্পন্ন হয়, যার ফলে ইনফারেন্সের গতি ব্যাপকভাবে বেড়ে যায়। এটি প্রতিটি টোকেনের জন্য টার্গেট মডেলকে যে সংখ্যক ফরোয়ার্ড পাস করতে হয়, তা কমিয়ে দেয়। যেহেতু আমাদের ড্রাফটার যাচাইকরণের জন্য টোকেনের একটি ক্রম তৈরি করে, তাই আমরা এটিকে মাল্টি-টোকেন প্রেডিকশন (MTP) হেড বলে থাকি।

পিএনজি

জেমা ৪ ফ্যামিলির জন্য প্রকাশিত ড্রাফট মডেলগুলো আকারে ছোট এবং এতে ড্রাফট করা টোকেনগুলোর মান উন্নত করতে ও ইনফারেন্সের গতি আরও বাড়াতে বেশ কিছু বর্ধিত বৈশিষ্ট্য যোগ করা হয়েছে; যেমন, আরও ভালো প্রেডিকশন পাওয়ার জন্য টার্গেট মডেল অ্যাক্টিভেশন এবং কেভি-ক্যাশ ব্যবহার করা।

এই উন্নয়নগুলোর ফলে ডিকোডিংয়ের গতি উল্লেখযোগ্যভাবে বৃদ্ধি পায় এবং একই সাথে গুণমানও অক্ষুণ্ণ থাকে, যা এই চেকপয়েন্টগুলোকে স্বল্প-বিলম্বের এবং ডিভাইস-ভিত্তিক অ্যাপ্লিকেশনগুলোর জন্য আদর্শ করে তোলে।

পাইথন প্যাকেজ ইনস্টল করুন

Gemma 4 এবং Gemma 4 অ্যাসিস্ট্যান্ট মডেল চালানোর জন্য প্রয়োজনীয় Hugging Face লাইব্রেরিগুলো ইনস্টল করুন।

# Install PyTorch & other libraries
pip install torch accelerate

# Install the transformers library
pip install transformers

মডেলগুলো লোড করুন

প্রতিটি টার্গেট মডেলের (জেমা ৪ মডেলের প্রধান মডেলগুলোর মধ্যে একটি) জন্য একটি অ্যাসিস্ট্যান্ট রয়েছে যা ইনফারেন্সের গতি বাড়াতে সাহায্য করে। সেই অনুযায়ী, আপনাকে দুটি মডেল লোড করতে হবে:

  • টার্গেট (যেমন, google/gemma-4-E2B-it ): পূর্ণাঙ্গ জেমা ৪ টার্গেট মডেল
  • ড্রাফটার (যেমন, google/gemma-4-E2B-it-assistant ): হালকা ওজনের ৪-স্তর বিশিষ্ট এমটিপি ড্রাফটার যা সম্ভাব্য টোকেন প্রস্তাব করে।

উল্লেখ্য যে, খসড়া প্রস্তুতকারীকে প্রায়শই সহকারী হিসাবে উল্লেখ করা হয়, কারণ মডেলটি বৃহত্তর মডেলকে কোন টোকেনগুলি ভবিষ্যদ্বাণী করতে হবে তা বেছে নিতে সাহায্য করে।

নিম্নলিখিত কোড উদাহরণে দেখানো অনুযায়ী transformers লাইব্রেরি ব্যবহার করে AutoProcessor এবং AutoModelForCausalLM ক্লাসগুলোর সাহায্যে একটি processor এবং model ইনস্ট্যান্স তৈরি করুন:

TARGET_MODEL_ID = "google/gemma-4-E2B-it" # @param ["google/gemma-4-E2B-it","google/gemma-4-E4B-it", "google/gemma-4-31B-it", "google/gemma-4-26B-A4B-it"]
ASSISTANT_MODEL_ID = TARGET_MODEL_ID + "-assistant"
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

# Target Model
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
target_model = AutoModelForCausalLM.from_pretrained(
    TARGET_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Assistant Model (the drafter)
assistant_model = AutoModelForCausalLM.from_pretrained(
    ASSISTANT_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
[transformers] `torch_dtype` is deprecated! Use `dtype` instead!
Loading weights:   0%|          | 0/1951 [00:00<?, ?it/s]
Loading weights:   0%|          | 0/50 [00:00<?, ?it/s]

সহকারীর সাথে জেমা ৪

সৌভাগ্যবশত, transformers অ্যাসিস্ট্যান্ট ব্যবহার করা বেশ সহজ এবং এর জন্য আপনাকে model.generate ফাংশনে অ্যাসিস্ট্যান্ট মডেলটি পাস করতে হবে:

# Process inputs with the `target_model`
messages = [
    {
        "role": "user",
        "content": "Explain the concepts of speculative decoding and MTP in 3 sentences."
    }
]
input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=input_text, return_tensors="pt").to(target_model.device)

# `assistant_model=assistant_model` is all you need to enable MTP!
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then quickly verified by a larger, more accurate model to produce a final, high-quality output much faster than decoding the large model alone. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly accelerate the inference speed of large language models while maintaining or improving output quality.

অভ্যন্তরীণ প্রক্রিয়াটি নিম্নরূপ:

  • খসড়া প্রস্তুতকারী অটোরেগ্রেসিভভাবে উৎপন্ন N সংখ্যক টোকেন প্রস্তাব করেন।
  • টার্গেট মডেলটি একটি ফরোয়ার্ড পাসে সমস্ত N টোকেন যাচাই করে।
  • উচ্চ সম্ভাবনাসম্পন্ন খসড়া টোকেনগুলো গৃহীত হয়।
  • কম সম্ভাবনাযুক্ত খসড়া টোকেনগুলি বাতিল করা হয়।
  • যেহেতু টার্গেট মডেলটি একটি ফরোয়ার্ড পাস করে, তাই কতগুলো ড্রাফট করা টোকেন গৃহীত বা প্রত্যাখ্যাত হয়েছে তা নির্বিশেষে এটি সর্বদা নিজে থেকেই ১টি টোকেন তৈরি করবে।

খসড়া টোকেন

খসড়া প্রস্তুতকারী টার্গেট মডেলের যাচাই করার জন্য যেকোনো সংখ্যক টোকেন তৈরি করতে পারেন। তবে, টার্গেট মডেল চাইলে নির্দিষ্ট কিছু টোকেন প্রত্যাখ্যান করতে পারে। যখন এটি তা করে, তখন তার পরবর্তী সমস্ত টোকেন উপেক্ষা করা হয়।

পিএনজি

সেই হিসেবে, ড্রাফট করা টোকেনের সংখ্যার জন্য বিভিন্ন মান ব্যবহার করার ক্ষেত্রে সুবিধা-অসুবিধাগুলো জানা গুরুত্বপূর্ণ।

আরও খসড়া টোকেন

যখন আপনি অনেকগুলো টোকেন ড্রাফট করেন (উদাহরণস্বরূপ ১৫টি), তখন সব টোকেন গৃহীত না হওয়ার সম্ভাবনা অনেক বেশি থাকে। ফলে, কম্পিউট অপচয়ের সম্ভাবনাও বেড়ে যায়। এর বিপরীতে, গ্রহণের হার বেশি হলে এটি ইনফারেন্সের গতি বাড়িয়ে দেওয়ার প্রবণতা দেখায়।

পিএনজি

কম খসড়া টোকেন

যখন আপনি কম টোকেন ড্রাফট করেন, তখন গ্রহণের হার বেশি হওয়ার প্রবণতা থাকে, কারণ প্রাথমিক প্রম্পটের কাছাকাছি থাকা টোকেনগুলো বেশি নির্ভুল হয়। তবে, যেহেতু অল্প কয়েকটি টোকেন ড্রাফট করা হয়, তাই একটি দ্রুততর ড্রাফটার মডেল থেকে যে গতি বৃদ্ধি পাওয়া যেত, তা কমে যায়।

পিএনজি

সৌভাগ্যবশত, transformers আপনার ব্যবহারের ক্ষেত্রের জন্য সেরা মান নিয়ে পরীক্ষা-নিরীক্ষা করার প্রয়োজন নেই, কারণ আপনি num_assistant_tokens_schedule কে "heuristic"-এ সেট করতে পারেন, যা রানটাইমে স্বয়ংক্রিয়ভাবে ড্রাফট করা টোকেনের সংখ্যা নির্ধারণ করবে।

  • সমস্ত টোকেন গৃহীত হয়েছে -- যেহেতু খসড়া প্রস্তুতকারী নির্দেশনার ক্ষেত্রে বেশ নির্ভুল, তাই খসড়া করার জন্য টোকেনের সংখ্যা ২ বৃদ্ধি করুন। যদি খসড়া করা টোকেনগুলোও গৃহীত হয়, তবে এর সংখ্যা বৃদ্ধি করলে কাজের গতি বাড়তে পারে।
  • যেকোনো টোকেন প্রত্যাখ্যাত হলে -- যদি কোনো টোকেন প্রত্যাখ্যাত হয়, তাহলে ড্রাফট করার জন্য টোকেনের পরিমাণ ১ কমিয়ে দিন। টোকেনের সংখ্যা কমানোর ফলে, যদি টার্গেট মডেলটি বেশিরভাগ টোকেন প্রত্যাখ্যান করতে থাকে, তাহলেও ড্রাফট করা খুব বেশি টোকেন নষ্ট হয় না।

একইভাবে, আপনি ড্রাফটারে num_assistant_tokens আপডেট করার মাধ্যমে ড্রাফট টোকেনের সংখ্যাও আপডেট করতে পারেন, যেমনটা নিচে দেখানো হয়েছে:

# Update how many draft tokens are generated at the start of inference
assistant_model.generation_config.num_assistant_tokens = 4

# Update how the number of draft tokens are updated ("heuristic" for a dynamic schedule and "constant" for a constant schedule)
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"

# Run with MTP
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then verified by a larger, more accurate model to quickly produce a high-quality output. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly speed up the inference process of large language models while maintaining or improving output quality.