|
|
Chạy trong Google Colab
|
|
|
Xem nguồn trên GitHub
|
Để cải thiện tốc độ suy luận của các mô hình Gemma 4, một loạt mô hình "soạn thảo" tự hồi quy mới đã được phát hành cùng với dòng sản phẩm chính. Thay vì chỉ dựa vào các mô hình Gemma 4 chính (được gọi là mô hình "mục tiêu"), mô hình nháp dự đoán một số mã thông báo tự hồi quy trong thời gian mà mô hình mục tiêu chỉ xử lý một mã thông báo. Kỹ thuật này còn được gọi là giải mã suy đoán.
Sau khi người soạn thảo dự đoán nhiều mã thông báo bản nháp, giờ đây, mô hình mục tiêu chỉ cần xác minh những mã thông báo bản nháp được đề xuất đó. Quá trình xác minh được thực hiện song song, nhờ đó tăng tốc đáng kể quá trình suy luận. Điều này giúp giảm số lượng lượt truyền xuôi mà mô hình mục tiêu phải thực hiện cho mỗi mã thông báo. Vì trình soạn thảo của chúng tôi tạo ra một chuỗi mã thông báo để xác minh, nên chúng tôi gọi đó là phần dự đoán nhiều mã thông báo (MTP).

Các mô hình nháp được phát hành cho họ Gemma 4 có kích thước nhỏ và có một số điểm cải tiến để nâng cao chất lượng của các mã thông báo nháp và tăng tốc suy luận hơn nữa, chẳng hạn như sử dụng các lượt kích hoạt mô hình mục tiêu và bộ nhớ đệm khoá-giá trị để đưa ra dự đoán chính xác hơn.
Những điểm cải tiến này giúp tăng tốc độ giải mã đáng kể trong khi vẫn đảm bảo chất lượng tương đương, giúp các điểm kiểm tra này trở nên hoàn hảo cho các ứng dụng có độ trễ thấp và ứng dụng trên thiết bị.
Cài đặt các gói Python
Cài đặt các thư viện Hugging Face cần thiết để chạy mô hình Gemma 4 và mô hình trợ lý Gemma 4.
# Install PyTorch & other librariespip install torch accelerate# Install the transformers librarypip install transformers
Tải các mô hình
Đối với mỗi mô hình mục tiêu (một trong những mô hình chính trong mô hình Gemma 4), có một trợ lý giúp tăng tốc suy luận. Do đó, bạn sẽ tải 2 mô hình:
- Mục tiêu (ví dụ:
google/gemma-4-E2B-it): Mô hình mục tiêu đầy đủ của Gemma 4 - Drafter (ví dụ:
google/gemma-4-E2B-it-assistant): Drafter MTP 4 lớp có trọng lượng nhẹ, đề xuất các mã thông báo đề xuất
Xin lưu ý rằng trình soạn thảo thường được gọi là trợ lý vì mô hình này giúp mô hình lớn hơn chọn những mã thông báo cần dự đoán.
Sử dụng các thư viện transformers để tạo một phiên bản của processor và model bằng cách sử dụng các lớp AutoProcessor và AutoModelForCausalLM như trong ví dụ mã sau:
TARGET_MODEL_ID = "google/gemma-4-E2B-it" # @param ["google/gemma-4-E2B-it","google/gemma-4-E4B-it", "google/gemma-4-31B-it", "google/gemma-4-26B-A4B-it"]
ASSISTANT_MODEL_ID = TARGET_MODEL_ID + "-assistant"
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
# Target Model
processor = AutoProcessor.from_pretrained(TARGET_MODEL_ID)
target_model = AutoModelForCausalLM.from_pretrained(
TARGET_MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
)
# Assistant Model (the drafter)
assistant_model = AutoModelForCausalLM.from_pretrained(
ASSISTANT_MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
)
[transformers] `torch_dtype` is deprecated! Use `dtype` instead! Loading weights: 0%| | 0/1951 [00:00<?, ?it/s] Loading weights: 0%| | 0/50 [00:00<?, ?it/s]
Gemma 4 có Trợ lý
Rất may là việc sử dụng một trợ lý trong transformers khá đơn giản và bạn chỉ cần truyền mô hình trợ lý vào hàm model.generate:
# Process inputs with the `target_model`
messages = [
{
"role": "user",
"content": "Explain the concepts of speculative decoding and MTP in 3 sentences."
}
]
input_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=input_text, return_tensors="pt").to(target_model.device)
# `assistant_model=assistant_model` is all you need to enable MTP!
outputs = target_model.generate(
**inputs,
assistant_model=assistant_model,
max_new_tokens=256,
do_sample=False,
)
# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then quickly verified by a larger, more accurate model to produce a final, high-quality output much faster than decoding the large model alone. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly accelerate the inference speed of large language models while maintaining or improving output quality.
Quy trình này diễn ra như sau:
- Người soạn thảo đề xuất N mã thông báo được tạo tự động hồi quy
- Mô hình đích xác minh tất cả N mã thông báo trong một lượt truyền xuôi
- Các mã thông báo được tạo nháp có xác suất cao sẽ được chấp nhận
- Các mã thông báo được tạo có xác suất thấp sẽ bị từ chối
- Vì mô hình mục tiêu thực hiện một lượt truyền xuôi, nên mô hình này sẽ luôn tự tạo ra 1 mã thông báo, bất kể có bao nhiêu mã thông báo nháp được chấp nhận hay bị từ chối
Mã thông báo nháp
Người soạn thảo có thể tạo bất kỳ số lượng mã thông báo nào cho mô hình mục tiêu để xác minh. Tuy nhiên, mô hình mục tiêu vẫn có thể chọn từ chối một số mã thông báo nhất định. Khi đó, tất cả mã thông báo sau mã thông báo đó đều bị bỏ qua.

Do đó, bạn cần biết sự đánh đổi khi sử dụng nhiều giá trị cho số lượng mã thông báo được tạo nháp.
Nhiều mã thông báo bản nháp hơn
Khi bạn tạo nhiều mã thông báo (ví dụ: 15), thì có nhiều khả năng không phải tất cả mã thông báo đều được chấp nhận. Do đó, có nhiều khả năng bạn sẽ lãng phí tài nguyên điện toán. Ngược lại, nó có xu hướng tăng tốc suy luận khi tỷ lệ chấp nhận cao.

Ít mã thông báo nháp hơn
Khi bạn tạo ít mã thông báo hơn, tỷ lệ chấp nhận có xu hướng cao hơn vì những mã thông báo có vị trí gần với câu lệnh ban đầu sẽ chính xác hơn. Tuy nhiên, vì chỉ có một vài mã thông báo được tạo, nên tốc độ mà bạn nhận được từ một mô hình tạo bản nháp nhanh hơn sẽ giảm.

Rất may là bạn không cần thử nghiệm các giá trị tốt nhất cho trường hợp sử dụng của mình trong transformers vì bạn có thể đặt num_assistant_tokens_schedule thành "heuristic" (dựa trên kinh nghiệm) để tự động điều chỉnh số lượng mã thông báo được tạo nháp trong thời gian chạy:
- Chấp nhận tất cả mã thông báo – Tăng số lượng mã thông báo cần tạo bản nháp thêm 2 vì người tạo bản nháp khá chính xác đối với câu lệnh. Việc tăng số lượng mã thông báo được tạo có thể giúp tăng tốc độ nếu những mã thông báo đó cũng được chấp nhận.
- Mọi mã thông báo bị từ chối – Nếu có mã thông báo bị từ chối, hãy giảm số lượng mã thông báo cần tạo bản nháp đi 1. Việc giảm số lượng mã thông báo sẽ giúp bạn không lãng phí quá nhiều mã thông báo nháp nếu mô hình mục tiêu tiếp tục từ chối hầu hết các mã thông báo.
Tương tự, bạn có thể cập nhật số lượng mã thông báo nháp bằng cách cập nhật num_assistant_tokens trong drafter như sau:
# Update how many draft tokens are generated at the start of inference
assistant_model.generation_config.num_assistant_tokens = 4
# Update how the number of draft tokens are updated ("heuristic" for a dynamic schedule and "constant" for a constant schedule)
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"
# Run with MTP
outputs = target_model.generate(
**inputs,
assistant_model=assistant_model,
max_new_tokens=256,
do_sample=False,
)
# Decode the response into text
response = processor.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
**Speculative decoding** is a technique where a smaller, faster language model (the "draft model") generates several candidate tokens, which are then verified by a larger, more accurate model to quickly produce a high-quality output. **MTP (Multi-Task Prediction)** involves training a single model to perform multiple related tasks simultaneously, allowing it to leverage shared knowledge across different objectives. Together, these methods aim to significantly speed up the inference process of large language models while maintaining or improving output quality.
Chạy trong Google Colab
Xem nguồn trên GitHub