Gemma 4 Dự đoán nhiều mã thông báo (MTP) bằng Hugging Face Transformers

Xem trên ai.google.dev Chạy trong Google Colab Chạy trong Kaggle Mở trong Vertex AI Xem nguồn trên GitHub

Để cải thiện tốc độ suy luận của các mô hình Gemma 4, một loạt mô hình "soạn thảo" tự hồi quy mới đã được phát hành cùng với dòng sản phẩm chính. Thay vì chỉ dựa vào các mô hình Gemma 4 chính (được gọi là mô hình "mục tiêu"), mô hình nháp dự đoán một số mã thông báo tự hồi quy trong thời gian mà mô hình mục tiêu chỉ xử lý một mã thông báo. Kỹ thuật này còn được gọi là giải mã suy đoán.

Sau khi người soạn thảo dự đoán nhiều mã thông báo bản nháp, giờ đây, mô hình mục tiêu chỉ cần xác minh những mã thông báo bản nháp được đề xuất đó. Quá trình xác minh được thực hiện song song, nhờ đó tăng tốc đáng kể quá trình suy luận. Điều này giúp giảm số lượng lượt truyền xuôi mà mô hình mục tiêu phải thực hiện cho mỗi mã thông báo. Vì trình soạn thảo của chúng tôi tạo ra một chuỗi mã thông báo để xác minh, nên chúng tôi gọi đó là phần dự đoán nhiều mã thông báo (MTP).

png

Các mô hình nháp được phát hành cho họ Gemma 4 có kích thước nhỏ và có một số điểm cải tiến để nâng cao chất lượng của các mã thông báo nháp và tăng tốc suy luận hơn nữa, chẳng hạn như sử dụng các lượt kích hoạt mô hình mục tiêu và bộ nhớ đệm khoá-giá trị để đưa ra dự đoán chính xác hơn.

Những điểm cải tiến này giúp tăng tốc độ giải mã đáng kể trong khi vẫn đảm bảo chất lượng tương đương, giúp các điểm kiểm tra này trở nên hoàn hảo cho các ứng dụng có độ trễ thấp và ứng dụng trên thiết bị.

Cài đặt các gói Python

Cài đặt các thư viện Hugging Face cần thiết để chạy mô hình Gemma 4 và mô hình trợ lý Gemma 4.

# Install PyTorch & other libraries
pip install torch accelerate

# Install the transformers library
pip install transformers

Tải các mô hình

Đối với mỗi mô hình mục tiêu (một trong những mô hình chính trong mô hình Gemma 4), có một trợ lý giúp tăng tốc suy luận. Do đó, bạn sẽ tải 2 mô hình:

  • Mục tiêu (ví dụ: google/gemma-4-E2B-it): Mô hình mục tiêu đầy đủ của Gemma 4
  • Drafter (ví dụ: google/gemma-4-E2B-it-assistant): Drafter MTP 4 lớp có trọng lượng nhẹ, đề xuất các mã thông báo đề xuất

Xin lưu ý rằng trình soạn thảo thường được gọi là trợ lý vì mô hình này giúp mô hình lớn hơn chọn những mã thông báo cần dự đoán.

Sử dụng các thư viện transformers để tạo một phiên bản của processormodel bằng cách sử dụng các lớp AutoProcessorAutoModelForCausalLM như trong ví dụ mã sau:

TARGET_MODEL_ID = "google/gemma-4-E2B-it" # @param ["google/gemma-4-E2B-it","google/gemma-4-E4B-it", "google/gemma-4-31B-it", "google/gemma-4-26B-A4B-it"]
ASSISTANT_MODEL_ID = TARGET_MODEL_ID + "-assistant"
import torch
from transformers import AutoProcessor, AutoModelForCausalLM

# Target Model
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
target_model = AutoModelForCausalLM.from_pretrained(
    TARGET_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Assistant Model (the drafter)
assistant_model = AutoModelForCausalLM.from_pretrained(
    ASSISTANT_MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
[transformers] `torch_dtype` is deprecated! Use `dtype` instead!
Loading weights:   0%|          | 0/1951 [00:00<?, ?it/s]
Loading weights:   0%|          | 0/50 [00:00<?, ?it/s]

Gemma 4 có Trợ lý

Rất may là việc sử dụng một trợ lý trong transformers khá đơn giản và bạn chỉ cần truyền mô hình trợ lý vào hàm model.generate:

# Process inputs with the `target_model`
messages = [
    {
        "role": "user",
        "content": "Explain the concepts of speculative decoding and MTP in 3 sentences."
    }
]
input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=input_text, return_tensors="pt").to(target_model.device)

# `assistant_model=assistant_model` is all you need to enable MTP!
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then quickly verified by a larger, more accurate model to produce a final, high-quality output much faster than decoding the large model alone. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly accelerate the inference speed of large language models while maintaining or improving output quality.

Quy trình này diễn ra như sau:

  • Người soạn thảo đề xuất N mã thông báo được tạo tự động hồi quy
  • Mô hình đích xác minh tất cả N mã thông báo trong một lượt truyền xuôi
  • Các mã thông báo được tạo nháp có xác suất cao sẽ được chấp nhận
  • Các mã thông báo được tạo có xác suất thấp sẽ bị từ chối
  • Vì mô hình mục tiêu thực hiện một lượt truyền xuôi, nên mô hình này sẽ luôn tự tạo ra 1 mã thông báo, bất kể có bao nhiêu mã thông báo nháp được chấp nhận hay bị từ chối

Mã thông báo nháp

Người soạn thảo có thể tạo bất kỳ số lượng mã thông báo nào cho mô hình mục tiêu để xác minh. Tuy nhiên, mô hình mục tiêu vẫn có thể chọn từ chối một số mã thông báo nhất định. Khi đó, tất cả mã thông báo sau mã thông báo đó đều bị bỏ qua.

png

Do đó, bạn cần biết sự đánh đổi khi sử dụng nhiều giá trị cho số lượng mã thông báo được tạo nháp.

Nhiều mã thông báo bản nháp hơn

Khi bạn tạo nhiều mã thông báo (ví dụ: 15), thì có nhiều khả năng không phải tất cả mã thông báo đều được chấp nhận. Do đó, có nhiều khả năng bạn sẽ lãng phí tài nguyên điện toán. Ngược lại, nó có xu hướng tăng tốc suy luận khi tỷ lệ chấp nhận cao.

png

Ít mã thông báo nháp hơn

Khi bạn tạo ít mã thông báo hơn, tỷ lệ chấp nhận có xu hướng cao hơn vì những mã thông báo có vị trí gần với câu lệnh ban đầu sẽ chính xác hơn. Tuy nhiên, vì chỉ có một vài mã thông báo được tạo, nên tốc độ mà bạn nhận được từ một mô hình tạo bản nháp nhanh hơn sẽ giảm.

png

Rất may là bạn không cần thử nghiệm các giá trị tốt nhất cho trường hợp sử dụng của mình trong transformers vì bạn có thể đặt num_assistant_tokens_schedule thành "heuristic" (dựa trên kinh nghiệm) để tự động điều chỉnh số lượng mã thông báo được tạo nháp trong thời gian chạy:

  • Chấp nhận tất cả mã thông báo – Tăng số lượng mã thông báo cần tạo bản nháp thêm 2 vì người tạo bản nháp khá chính xác đối với câu lệnh. Việc tăng số lượng mã thông báo được tạo có thể giúp tăng tốc độ nếu những mã thông báo đó cũng được chấp nhận.
  • Mọi mã thông báo bị từ chối – Nếu có mã thông báo bị từ chối, hãy giảm số lượng mã thông báo cần tạo bản nháp đi 1. Việc giảm số lượng mã thông báo sẽ giúp bạn không lãng phí quá nhiều mã thông báo nháp nếu mô hình mục tiêu tiếp tục từ chối hầu hết các mã thông báo.

Tương tự, bạn có thể cập nhật số lượng mã thông báo nháp bằng cách cập nhật num_assistant_tokens trong drafter như sau:

# Update how many draft tokens are generated at the start of inference
assistant_model.generation_config.num_assistant_tokens = 4

# Update how the number of draft tokens are updated ("heuristic" for a dynamic schedule and "constant" for a constant schedule)
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"

# Run with MTP
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    max_new_tokens=256,
    do_sample=False,
)

# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then verified by a larger, more accurate model to quickly produce a high-quality output. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly speed up the inference process of large language models while maintaining or improving output quality.