Hướng dẫn này cho bạn biết cách chạy Gemma bằng khung PyTorch, bao gồm cả cách sử dụng dữ liệu hình ảnh để nhắc các mô hình Gemma phiên bản 3 trở lên. Để biết thêm thông tin chi tiết về cách triển khai Gemma PyTorch, hãy xem README của kho lưu trữ dự án.
Thiết lập
Các phần sau giải thích cách thiết lập môi trường phát triển, bao gồm cả cách truy cập vào các mô hình Gemma để tải xuống từ Kaggle, thiết lập biến xác thực, cài đặt phần phụ thuộc và nhập gói.
Yêu cầu hệ thống
Thư viện Gemma Pytorch này yêu cầu bộ xử lý GPU hoặc TPU để chạy mô hình Gemma. Thời gian chạy Python trên CPU Colab tiêu chuẩn và thời gian chạy Python trên GPU T4 là đủ để chạy các mô hình có kích thước Gemma 1B, 2B và 4B. Đối với các trường hợp sử dụng nâng cao cho GPU hoặc TPU khác, vui lòng tham khảo README trong kho lưu trữ Gemma PyTorch.
Truy cập vào Gemma trên Kaggle
Để hoàn tất hướng dẫn này, trước tiên, bạn cần làm theo hướng dẫn thiết lập tại phần Thiết lập Gemma. Phần này sẽ hướng dẫn bạn cách thực hiện những việc sau:
- Truy cập vào Gemma trên kaggle.com.
- Chọn một môi trường thời gian chạy Colab có đủ tài nguyên để chạy mô hình Gemma.
- Tạo và định cấu hình tên người dùng và khoá API của Kaggle.
Sau khi hoàn tất việc thiết lập Gemma, hãy chuyển sang phần tiếp theo để thiết lập biến môi trường cho môi trường Colab.
Đặt các biến môi trường
Đặt biến môi trường cho KAGGLE_USERNAME
và KAGGLE_KEY
. Khi bạn nhận được thông báo "Cấp quyền truy cập?", hãy đồng ý cấp quyền truy cập vào khoá bí mật.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Cài đặt phần phụ thuộc
pip install -q -U torch immutabledict sentencepiece
Tải trọng số mô hình xuống
# Choose variant and machine type
VARIANT = '4b-it'
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT[:2]
if CONFIG == '4b':
CONFIG = '4b-v1'
import kagglehub
# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')
Đặt trình phân tích cú pháp và đường dẫn điểm kiểm tra cho mô hình.
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'
# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'
Định cấu hình môi trường chạy
Các phần sau đây giải thích cách chuẩn bị môi trường PyTorch để chạy Gemma.
Chuẩn bị môi trường chạy PyTorch
Chuẩn bị môi trường thực thi mô hình PyTorch bằng cách nhân bản kho lưu trữ Gemma Pytorch.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'... remote: Enumerating objects: 239, done. remote: Counting objects: 100% (123/123), done. remote: Compressing objects: 100% (68/68), done. remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116 Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done. Resolving deltas: 100% (135/135), done.
import sys
sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM
import os
import torch
Đặt cấu hình mô hình
Trước khi chạy mô hình, bạn phải thiết lập một số tham số cấu hình, bao gồm biến thể Gemma, trình phân tích cú pháp và cấp độ lượng tử hoá.
# Set up model config.
model_config = get_model_config(VARIANT)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path
Định cấu hình ngữ cảnh thiết bị
Mã sau đây định cấu hình ngữ cảnh thiết bị để chạy mô hình:
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
Tạo bản sao và tải mô hình
Tải mô hình bằng trọng số để chuẩn bị chạy các yêu cầu.
device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
model = Gemma3ForMultimodalLM(model_config)
model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
model = model.to(device).eval()
print("Model loading done.")
print('Generating requests in chat mode...')
Chạy quy trình suy luận
Dưới đây là ví dụ về cách tạo ở chế độ trò chuyện và tạo bằng nhiều yêu cầu.
Các mô hình Gemma được điều chỉnh hướng dẫn được huấn luyện bằng một trình định dạng cụ thể chú thích các ví dụ về việc điều chỉnh hướng dẫn bằng thông tin bổ sung, cả trong quá trình huấn luyện và suy luận. Chú giải (1) cho biết vai trò trong cuộc trò chuyện và (2) phân định lượt nói trong cuộc trò chuyện.
Các mã thông báo chú thích có liên quan là:
user
: lượt người dùngmodel
: lượt mô hình<start_of_turn>
: bắt đầu lượt thoại<start_of_image>
: thẻ để nhập dữ liệu hình ảnh<end_of_turn><eos>
: kết thúc lượt thoại
Để biết thêm thông tin, hãy đọc về cách định dạng câu lệnh cho các mô hình Gemma được điều chỉnh theo hướng dẫn [tại đây](https://ai.google.dev/gemma/core/prompt-structure
Tạo văn bản có văn bản
Sau đây là một đoạn mã mẫu minh hoạ cách định dạng lời nhắc cho mô hình Gemma được điều chỉnh theo hướng dẫn bằng cách sử dụng mẫu trò chuyện của người dùng và mô hình trong cuộc trò chuyện nhiều lượt.
# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"
# Sample formatted prompt
prompt = (
USER_CHAT_TEMPLATE.format(
prompt='What is a good place for travel in the US?'
)
+ MODEL_CHAT_TEMPLATE.format(prompt='California.')
+ USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
+ '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)
model.generate(
USER_CHAT_TEMPLATE.format(prompt=prompt),
device=device,
output_len=256,
)
Chat prompt: <start_of_turn>user What is a good place for travel in the US?<end_of_turn><eos> <start_of_turn>model California.<end_of_turn><eos> <start_of_turn>user What can I do in California?<end_of_turn><eos> <start_of_turn>model "California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo) \n\nThe more you tell me, the better recommendations I can give! 😊 \n<end_of_turn>"
# Generate sample
model.generate(
'Write a poem about an llm writing a poem.',
device=device,
output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"
Tạo văn bản có hình ảnh
Với Gemma phiên bản 3 trở lên, bạn có thể sử dụng hình ảnh cùng với câu lệnh. Ví dụ sau đây cho bạn biết cách đưa dữ liệu hình ảnh vào câu lệnh.
print('Chat with images...\n')
def read_image(url):
import io
import requests
import PIL
contents = io.BytesIO(requests.get(url).content)
return PIL.Image.open(contents)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
image = read_image(image_url)
print(model.generate(
[['<start_of_turn>user\n',image, 'What animal is in this image?<end_of_turn>\n', '<start_of_turn>model\n']],
device=device,
output_len=OUTPUT_LEN,
))
Tìm hiểu thêm
Giờ đây, khi đã tìm hiểu cách sử dụng Gemma trong Pytorch, bạn có thể khám phá nhiều điều khác mà Gemma có thể làm được tại ai.google.dev/gemma. Hãy xem thêm các tài nguyên có liên quan sau: