Chạy Gemma bằng PyTorch

Xem trên ai.google.dev Chạy trong Google Colab Xem nguồn trên GitHub

Hướng dẫn này sẽ hướng dẫn bạn cách chạy Gemma bằng khung PyTorch, bao gồm cả cách sử dụng dữ liệu hình ảnh để nhắc các mô hình Gemma phiên bản 3 trở lên. Để biết thêm thông tin chi tiết về việc triển khai Gemma PyTorch, hãy xem README của kho lưu trữ dự án.

Thiết lập

Các phần sau đây giải thích cách thiết lập môi trường phát triển, bao gồm cả cách truy cập vào các mô hình Gemma để tải xuống từ Kaggle, thiết lập các biến xác thực, cài đặt các phần phụ thuộc và nhập các gói.

Yêu cầu hệ thống

Thư viện Gemma Pytorch này yêu cầu bộ xử lý GPU hoặc TPU để chạy mô hình Gemma. Thời gian chạy Python CPU Colab tiêu chuẩn và thời gian chạy Python GPU T4 là đủ để chạy các mô hình có kích thước 1B, 2B và 4B của Gemma. Đối với các trường hợp sử dụng nâng cao cho GPU hoặc TPU khác, vui lòng tham khảo README trong kho lưu trữ Gemma PyTorch.

Truy cập Gemma trên Kaggle

Để hoàn tất hướng dẫn này, trước tiên, bạn cần làm theo hướng dẫn thiết lập tại Thiết lập Gemma. Hướng dẫn này sẽ cho bạn biết cách thực hiện những việc sau:

  • Truy cập vào Gemma trên Kaggle.
  • Chọn một thời gian chạy Colab có đủ tài nguyên để chạy mô hình Gemma.
  • Tạo và định cấu hình tên người dùng và khoá API Kaggle.

Sau khi hoàn tất quá trình thiết lập Gemma, hãy chuyển sang phần tiếp theo, nơi bạn sẽ thiết lập các biến môi trường cho môi trường Colab.

Đặt các biến môi trường

Đặt các biến môi trường cho KAGGLE_USERNAMEKAGGLE_KEY. Khi được nhắc bằng thông báo "Bạn có muốn cấp quyền truy cập không?", hãy đồng ý cấp quyền truy cập vào khoá bí mật.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Cài đặt các phần phụ thuộc

pip install -q -U torch immutabledict sentencepiece

Tải trọng số mô hình xuống

# Choose variant and machine type
VARIANT = '4b-it' 
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT.split('-')[0]
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')

Đặt đường dẫn của trình mã hoá từ và điểm kiểm tra cho mô hình.

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Định cấu hình môi trường chạy

Các phần sau đây giải thích cách chuẩn bị môi trường PyTorch để chạy Gemma.

Chuẩn bị môi trường chạy PyTorch

Chuẩn bị môi trường thực thi mô hình PyTorch bằng cách sao chép kho lưu trữ Gemma Pytorch.

git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM

import os
import torch

Đặt cấu hình mô hình

Trước khi chạy mô hình, bạn phải thiết lập một số thông số cấu hình, bao gồm cả biến thể Gemma, mã hoá từ và mức độ lượng tử hoá.

# Set up model config.
model_config = get_model_config(CONFIG)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path

Định cấu hình bối cảnh thiết bị

Đoạn mã sau đây định cấu hình ngữ cảnh thiết bị để chạy mô hình:

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

Tạo thực thể và tải mô hình

Tải mô hình cùng với các trọng số của mô hình để chuẩn bị chạy các yêu cầu.

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = Gemma3ForMultimodalLM(model_config)
    model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
    model = model.to(device).eval()
print("Model loading done.")

print('Generating requests in chat mode...')

Chạy suy luận

Dưới đây là ví dụ về cách tạo ở chế độ trò chuyện và tạo bằng nhiều yêu cầu.

Các mô hình Gemma được tinh chỉnh theo hướng dẫn được huấn luyện bằng một trình định dạng cụ thể, chú thích các ví dụ tinh chỉnh theo hướng dẫn bằng thông tin bổ sung, cả trong quá trình huấn luyện và suy luận. Chú thích (1) cho biết vai trò trong cuộc trò chuyện và (2) phân định lượt trong cuộc trò chuyện.

Các mã thông báo chú thích có liên quan là:

  • user: lượt của người dùng
  • model: lượt của mô hình
  • <start_of_turn>: bắt đầu lượt hội thoại
  • <start_of_image>: thẻ để nhập dữ liệu hình ảnh
  • <end_of_turn><eos>: cuối lượt trò chuyện

Để biết thêm thông tin, hãy đọc về định dạng câu lệnh cho các mô hình Gemma được điều chỉnh theo hướng dẫn tại đây.

Tạo văn bản bằng văn bản

Sau đây là một đoạn mã mẫu minh hoạ cách định dạng câu lệnh cho một mô hình Gemma được tinh chỉnh theo hướng dẫn bằng cách sử dụng mẫu trò chuyện của người dùng và mô hình trong một cuộc trò chuyện nhiều lượt.

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=256,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

Tạo văn bản bằng hình ảnh

Với Gemma phiên bản 3 trở lên, bạn có thể sử dụng hình ảnh trong câu lệnh. Ví dụ sau đây cho thấy cách đưa dữ liệu trực quan vào câu lệnh.

print('Chat with images...\n')

def read_image(url):
    import io
    import requests
    import PIL

    contents = io.BytesIO(requests.get(url).content)
    return PIL.Image.open(contents)

image = read_image(
    'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
)

print(model.generate(
    [
        [
            '<start_of_turn>user\n',
            image,
            'What animal is in this image?<end_of_turn>\n',
            '<start_of_turn>model\n'
        ]
    ],
    device=device,
    output_len=256,
))

Tìm hiểu thêm

Giờ đây, bạn đã học được cách sử dụng Gemma trong Pytorch, bạn có thể khám phá nhiều việc khác mà Gemma có thể làm tại ai.google.dev/gemma.

Bạn cũng có thể tham khảo các tài nguyên khác có liên quan sau đây: