Uruchamianie Gemma za pomocą PyTorch

Wyświetl na ai.google.dev Uruchom w Google Colab Wyświetl źródło w GitHubie

Z tego przewodnika dowiesz się, jak uruchomić model Gemma za pomocą platformy PyTorch, w tym jak używać danych obrazów do promptowania modeli Gemma w wersji 3 i nowszych. Więcej informacji o implementacji Gemmy w PyTorch znajdziesz w pliku README w repozytorium projektu.

Konfiguracja

W sekcjach poniżej znajdziesz informacje o konfigurowaniu środowiska programistycznego, w tym o tym, jak uzyskać dostęp do modeli Gemma do pobrania z Kaggle, ustawić zmienne uwierzytelniania, zainstalować zależności i zaimportować pakiety.

Wymagania systemowe

Ta biblioteka Gemma Pytorch wymaga procesorów GPU lub TPU do uruchomienia modelu Gemma. Standardowe środowisko wykonawcze Pythona na procesorze i środowisko wykonawcze Pythona na procesorze graficznym T4 w Colab są wystarczające do uruchamiania modeli Gemma o rozmiarach 1B, 2B i 4B. W przypadku zaawansowanych zastosowań innych procesorów graficznych lub TPU zapoznaj się z plikiem README w repozytorium Gemma PyTorch.

Uzyskiwanie dostępu do Gemy w Kaggle

Aby ukończyć ten samouczek, musisz najpierw wykonać instrukcje konfiguracji podane w artykule Konfiguracja Gemmy, z którego dowiesz się, jak wykonać te czynności:

  • Uzyskaj dostęp do Gemy na Kaggle.
  • Wybierz środowisko wykonawcze Colab z zasobami wystarczającymi do uruchomienia modelu Gemma.
  • Wygeneruj i skonfiguruj nazwę użytkownika i klucz interfejsu API Kaggle.

Po zakończeniu konfiguracji Gemy przejdź do następnej sekcji, w której ustawisz zmienne środowiskowe dla środowiska Colab.

Ustawianie zmiennych środowiskowych

Ustaw zmienne środowiskowe dla KAGGLE_USERNAMEKAGGLE_KEY. Gdy pojawi się prośba „Przyznać dostęp?”, wyraź zgodę na przyznanie dostępu do obiektu tajnego.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Instalowanie zależności

pip install -q -U torch immutabledict sentencepiece

Pobieranie wag modelu

# Choose variant and machine type
VARIANT = '4b-it' 
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT.split('-')[0]
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')

Ustaw ścieżki tokenizera i punktu kontrolnego dla modelu.

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Konfigurowanie środowiska wykonawczego

W kolejnych sekcjach dowiesz się, jak przygotować środowisko PyTorch do uruchamiania modelu Gemma.

Przygotowywanie środowiska wykonawczego PyTorch

Przygotuj środowisko wykonawcze modelu PyTorch, klonując repozytorium Gemma Pytorch.

git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM

import os
import torch

Ustaw konfigurację modelu

Zanim uruchomisz model, musisz ustawić kilka parametrów konfiguracji, w tym wariant Gemmy, tokenizator i poziom kwantyzacji.

# Set up model config.
model_config = get_model_config(CONFIG)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path

Konfigurowanie kontekstu urządzenia

Ten kod konfiguruje kontekst urządzenia do uruchomienia modelu:

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

Utwórz instancję modelu i wczytaj go

Załaduj model z wagami, aby przygotować się do uruchamiania żądań.

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = Gemma3ForMultimodalLM(model_config)
    model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
    model = model.to(device).eval()
print("Model loading done.")

print('Generating requests in chat mode...')

Uruchamianie wnioskowania

Poniżej znajdziesz przykłady generowania w trybie czatu i generowania za pomocą wielu żądań.

Modele Gemma dostrojone pod kątem instrukcji zostały wytrenowane przy użyciu specjalnego formatera, który dodaje do przykładów dostrajania pod kątem instrukcji dodatkowe informacje zarówno podczas trenowania, jak i wnioskowania. Adnotacje (1) wskazują role w rozmowie i (2) wyznaczają kolejne wypowiedzi.

Odpowiednie tokeny adnotacji to:

  • user: kolejka użytkownika
  • model: kolejka modelu
  • <start_of_turn>: początek kolejki w dialogu
  • <start_of_image>: tag do wprowadzania danych obrazu
  • <end_of_turn><eos>: koniec tury dialogu

Więcej informacji o formatowaniu promptów w przypadku modeli Gemma dostosowanych do instrukcji znajdziesz tutaj.

Generowanie tekstu za pomocą tekstu

Poniżej znajdziesz przykładowy fragment kodu, który pokazuje, jak sformatować prompta dla modelu Gemma dostosowanego do instrukcji za pomocą szablonów czatu użytkownika i modelu w wielokrotnej rozmowie.

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=256,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

Generowanie tekstu z obrazami

W przypadku modeli Gemma w wersji 3 i nowszych możesz używać obrazów w prompcie. Poniższy przykład pokazuje, jak dołączyć dane wizualne do promptu.

print('Chat with images...\n')

def read_image(url):
    import io
    import requests
    import PIL

    contents = io.BytesIO(requests.get(url).content)
    return PIL.Image.open(contents)

image = read_image(
    'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
)

print(model.generate(
    [
        [
            '<start_of_turn>user\n',
            image,
            'What animal is in this image?<end_of_turn>\n',
            '<start_of_turn>model\n'
        ]
    ],
    device=device,
    output_len=256,
))

Więcej informacji

Teraz, gdy wiesz już, jak używać Gemmy w PyTorch, możesz poznać wiele innych możliwości tego modelu na stronie ai.google.dev/gemma.

Zapoznaj się też z tymi materiałami: