Xem trên ai.google.dev | Chạy trong Google Colab | Xem nguồn trên GitHub |
Đây là bản minh hoạ nhanh về cách chạy suy luận của Gemma trong PyTorch. Để biết thêm chi tiết, vui lòng xem kho lưu trữ GitHub của hoạt động triển khai PyTorch chính thức tại đây.
Lưu ý rằng:
- Môi trường thời gian chạy miễn phí Colab CPU Python và GPU T4 là đủ để chạy các mô hình Gemma 2B và các mô hình lượng tử hóa 7B int8.
- Đối với các trường hợp sử dụng nâng cao cho GPU hoặc TPU khác, vui lòng tham khảo README.md trong kho lưu trữ chính thức.
1. Thiết lập quyền truy cập vào Kaggle cho Gemma
Để hoàn tất hướng dẫn này, trước tiên bạn cần làm theo hướng dẫn thiết lập trong phần thiết lập Gemma. Các hướng dẫn này sẽ cho bạn biết cách thực hiện những thao tác sau:
- Truy cập vào Gemma trên kaggle.com.
- Chọn một môi trường thời gian chạy Colab có đủ tài nguyên để chạy mô hình Gemma.
- Tạo và định cấu hình tên người dùng và khoá API của Kaggle.
Sau khi hoàn tất việc thiết lập Gemma, hãy chuyển sang phần tiếp theo để thiết lập biến môi trường cho môi trường Colab.
2. Đặt các biến môi trường
Thiết lập các biến môi trường cho KAGGLE_USERNAME
và KAGGLE_KEY
. Khi được nhắc với thông báo "Cấp quyền truy cập?", hãy đồng ý cung cấp quyền truy cập bí mật.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Cài đặt phần phụ thuộc
pip install -q -U torch immutabledict sentencepiece
Tải trọng số mô hình xuống
# Choose variant and machine type
VARIANT = '2b-it'
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT[:2]
if CONFIG == '2b':
CONFIG = '2b-v2'
import os
import kagglehub
# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-2/pyTorch/gemma-2-{VARIANT}')
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'
# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'
Tải phần triển khai mô hình xuống
# NOTE: The "installation" is just cloning the repo.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'... remote: Enumerating objects: 239, done. remote: Counting objects: 100% (123/123), done. remote: Compressing objects: 100% (68/68), done. remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116 Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done. Resolving deltas: 100% (135/135), done.
import sys
sys.path.append('gemma_pytorch')
from gemma.config import GemmaConfig, get_model_config
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch
Thiết lập mô hình
# Set up model config.
model_config = get_model_config(CONFIG)
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT
# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()
Chạy quy trình suy luận
Dưới đây là ví dụ về cách tạo ở chế độ trò chuyện và tạo bằng nhiều yêu cầu.
Các mô hình Gemma được điều chỉnh hướng dẫn được huấn luyện bằng một trình định dạng cụ thể chú thích các ví dụ về việc điều chỉnh hướng dẫn bằng thông tin bổ sung, cả trong quá trình huấn luyện và suy luận. Chú giải (1) cho biết vai trò trong cuộc trò chuyện và (2) phân định lượt nói trong cuộc trò chuyện.
Các mã thông báo chú thích có liên quan là:
user
: lượt người dùngmodel
: lượt mô hình<start_of_turn>
: bắt đầu lượt thoại<end_of_turn><eos>
: kết thúc lượt thoại
Để biết thêm thông tin, hãy đọc về định dạng lời nhắc cho hướng dẫn về các mô hình Gemma được điều chỉnh tại đây.
Sau đây là một đoạn mã mẫu minh hoạ cách định dạng lời nhắc cho mô hình Gemma được điều chỉnh theo hướng dẫn bằng cách sử dụng mẫu trò chuyện của người dùng và mô hình trong một cuộc trò chuyện nhiều lượt.
# Generate with one request in chat mode
# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"
# Sample formatted prompt
prompt = (
USER_CHAT_TEMPLATE.format(
prompt='What is a good place for travel in the US?'
)
+ MODEL_CHAT_TEMPLATE.format(prompt='California.')
+ USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
+ '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)
model.generate(
USER_CHAT_TEMPLATE.format(prompt=prompt),
device=device,
output_len=128,
)
Chat prompt: <start_of_turn>user What is a good place for travel in the US?<end_of_turn><eos> <start_of_turn>model California.<end_of_turn><eos> <start_of_turn>user What can I do in California?<end_of_turn><eos> <start_of_turn>model "California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo) \n\nThe more you tell me, the better recommendations I can give! 😊 \n<end_of_turn>"
# Generate sample
model.generate(
'Write a poem about an llm writing a poem.',
device=device,
output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"
Tìm hiểu thêm
Giờ đây, khi đã tìm hiểu cách sử dụng Gemma trong Pytorch, bạn có thể khám phá nhiều điều khác mà Gemma có thể làm trong ai.google.dev/gemma. Ngoài ra, hãy xem các tài nguyên có liên quan khác sau đây: