![]() |
![]() |
![]() |
Z tego przewodnika dowiesz się, jak uruchomić model Gemma za pomocą platformy PyTorch, w tym jak używać danych obrazów do promptowania modeli Gemma w wersji 3 i nowszych. Więcej informacji o implementacji Gemmy w PyTorch znajdziesz w pliku README w repozytorium projektu.
Konfiguracja
W sekcjach poniżej znajdziesz informacje o konfigurowaniu środowiska programistycznego, w tym o tym, jak uzyskać dostęp do modeli Gemma do pobrania z Kaggle, ustawić zmienne uwierzytelniania, zainstalować zależności i zaimportować pakiety.
Wymagania systemowe
Ta biblioteka Gemma Pytorch wymaga procesorów GPU lub TPU do uruchomienia modelu Gemma. Standardowe środowisko wykonawcze Pythona na procesorze i środowisko wykonawcze Pythona na procesorze graficznym T4 w Colab są wystarczające do uruchamiania modeli Gemma o rozmiarach 1B, 2B i 4B. W przypadku zaawansowanych zastosowań innych procesorów graficznych lub TPU zapoznaj się z plikiem README w repozytorium Gemma PyTorch.
Uzyskiwanie dostępu do Gemy w Kaggle
Aby ukończyć ten samouczek, musisz najpierw wykonać instrukcje konfiguracji podane w artykule Konfiguracja Gemmy, z którego dowiesz się, jak wykonać te czynności:
- Uzyskaj dostęp do Gemy na Kaggle.
- Wybierz środowisko wykonawcze Colab z zasobami wystarczającymi do uruchomienia modelu Gemma.
- Wygeneruj i skonfiguruj nazwę użytkownika i klucz interfejsu API Kaggle.
Po zakończeniu konfiguracji Gemy przejdź do następnej sekcji, w której ustawisz zmienne środowiskowe dla środowiska Colab.
Ustawianie zmiennych środowiskowych
Ustaw zmienne środowiskowe dla KAGGLE_USERNAME
i KAGGLE_KEY
. Gdy pojawi się prośba „Przyznać dostęp?”, wyraź zgodę na przyznanie dostępu do obiektu tajnego.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Instalowanie zależności
pip install -q -U torch immutabledict sentencepiece
Pobieranie wag modelu
# Choose variant and machine type
VARIANT = '4b-it'
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT.split('-')[0]
import kagglehub
# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')
Ustaw ścieżki tokenizera i punktu kontrolnego dla modelu.
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'
# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'
Konfigurowanie środowiska wykonawczego
W kolejnych sekcjach dowiesz się, jak przygotować środowisko PyTorch do uruchamiania modelu Gemma.
Przygotowywanie środowiska wykonawczego PyTorch
Przygotuj środowisko wykonawcze modelu PyTorch, klonując repozytorium Gemma Pytorch.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'... remote: Enumerating objects: 239, done. remote: Counting objects: 100% (123/123), done. remote: Compressing objects: 100% (68/68), done. remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116 Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done. Resolving deltas: 100% (135/135), done.
import sys
sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM
import os
import torch
Ustaw konfigurację modelu
Zanim uruchomisz model, musisz ustawić kilka parametrów konfiguracji, w tym wariant Gemmy, tokenizator i poziom kwantyzacji.
# Set up model config.
model_config = get_model_config(CONFIG)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path
Konfigurowanie kontekstu urządzenia
Ten kod konfiguruje kontekst urządzenia do uruchomienia modelu:
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
Utwórz instancję modelu i wczytaj go
Załaduj model z wagami, aby przygotować się do uruchamiania żądań.
device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
model = Gemma3ForMultimodalLM(model_config)
model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
model = model.to(device).eval()
print("Model loading done.")
print('Generating requests in chat mode...')
Uruchamianie wnioskowania
Poniżej znajdziesz przykłady generowania w trybie czatu i generowania za pomocą wielu żądań.
Modele Gemma dostrojone pod kątem instrukcji zostały wytrenowane przy użyciu specjalnego formatera, który dodaje do przykładów dostrajania pod kątem instrukcji dodatkowe informacje zarówno podczas trenowania, jak i wnioskowania. Adnotacje (1) wskazują role w rozmowie i (2) wyznaczają kolejne wypowiedzi.
Odpowiednie tokeny adnotacji to:
user
: kolejka użytkownikamodel
: kolejka modelu<start_of_turn>
: początek kolejki w dialogu<start_of_image>
: tag do wprowadzania danych obrazu<end_of_turn><eos>
: koniec tury dialogu
Więcej informacji o formatowaniu promptów w przypadku modeli Gemma dostosowanych do instrukcji znajdziesz tutaj.
Generowanie tekstu za pomocą tekstu
Poniżej znajdziesz przykładowy fragment kodu, który pokazuje, jak sformatować prompta dla modelu Gemma dostosowanego do instrukcji za pomocą szablonów czatu użytkownika i modelu w wielokrotnej rozmowie.
# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"
# Sample formatted prompt
prompt = (
USER_CHAT_TEMPLATE.format(
prompt='What is a good place for travel in the US?'
)
+ MODEL_CHAT_TEMPLATE.format(prompt='California.')
+ USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
+ '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)
model.generate(
USER_CHAT_TEMPLATE.format(prompt=prompt),
device=device,
output_len=256,
)
Chat prompt: <start_of_turn>user What is a good place for travel in the US?<end_of_turn><eos> <start_of_turn>model California.<end_of_turn><eos> <start_of_turn>user What can I do in California?<end_of_turn><eos> <start_of_turn>model "California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo) \n\nThe more you tell me, the better recommendations I can give! 😊 \n<end_of_turn>"
# Generate sample
model.generate(
'Write a poem about an llm writing a poem.',
device=device,
output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"
Generowanie tekstu z obrazami
W przypadku modeli Gemma w wersji 3 i nowszych możesz używać obrazów w prompcie. Poniższy przykład pokazuje, jak dołączyć dane wizualne do promptu.
print('Chat with images...\n')
def read_image(url):
import io
import requests
import PIL
contents = io.BytesIO(requests.get(url).content)
return PIL.Image.open(contents)
image = read_image(
'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
)
print(model.generate(
[
[
'<start_of_turn>user\n',
image,
'What animal is in this image?<end_of_turn>\n',
'<start_of_turn>model\n'
]
],
device=device,
output_len=256,
))
Więcej informacji
Teraz, gdy wiesz już, jak używać Gemmy w PyTorch, możesz poznać wiele innych możliwości tego modelu na stronie ai.google.dev/gemma.
Zapoznaj się też z tymi materiałami: