In diesem Leitfaden erfahren Sie, wie Sie Gemma mit dem PyTorch-Framework ausführen, einschließlich der Verwendung von Bilddaten für die Prompts von Gemma Release 3 und höher. Weitere Informationen zur Gemma PyTorch-Implementierung finden Sie in der README-Datei des Projekt-Repositories.
Einrichtung
In den folgenden Abschnitten wird beschrieben, wie Sie Ihre Entwicklungsumgebung einrichten. Dazu gehört auch, wie Sie Zugriff auf Gemma-Modelle für den Download von Kaggle erhalten, Authentifizierungsvariablen festlegen, Abhängigkeiten installieren und Pakete importieren.
Systemanforderungen
Für diese Gemma-Pytorch-Bibliothek sind GPU- oder TPU-Prozessoren erforderlich, um das Gemma-Modell auszuführen. Die standardmäßige Python-Laufzeit für Colab-CPUs und T4-GPUs reicht aus, um Gemma-Modelle der Größe 1 B, 2 B und 4 B auszuführen. Informationen zu erweiterten Anwendungsfällen für andere GPUs oder TPUs finden Sie in der README im Gemma PyTorch-Repository.
Zugriff auf Gemma auf Kaggle erhalten
Bevor Sie mit dieser Anleitung fortfahren können, müssen Sie zuerst die Einrichtungsanleitung unter Gemma einrichten befolgen. Dort erfahren Sie, wie Sie Folgendes tun:
- Sie können Gemma unter kaggle.com nutzen.
- Wählen Sie eine Colab-Laufzeit mit ausreichenden Ressourcen für die Ausführung des Gemma-Modells aus.
- Erstellen und konfigurieren Sie einen Kaggle-Nutzernamen und einen API-Schlüssel.
Nachdem Sie die Gemma-Einrichtung abgeschlossen haben, fahren Sie mit dem nächsten Abschnitt fort, in dem Sie Umgebungsvariablen für Ihre Colab-Umgebung festlegen.
Umgebungsvariablen festlegen
Legen Sie Umgebungsvariablen für KAGGLE_USERNAME
und KAGGLE_KEY
fest. Wenn Sie die Meldung „Zugriff gewähren?“ sehen, stimmen Sie zu, den geheimen Zugriff zu gewähren.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Abhängigkeiten installieren
pip install -q -U torch immutabledict sentencepiece
Modellgewichte herunterladen
# Choose variant and machine type
VARIANT = '4b-it'
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT[:2]
if CONFIG == '4b':
CONFIG = '4b-v1'
import kagglehub
# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')
Legen Sie die Pfade für den Tokenizer und den Checkpoint für das Modell fest.
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'
# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'
Ausführungsumgebung konfigurieren
In den folgenden Abschnitten wird erläutert, wie Sie eine PyTorch-Umgebung für die Ausführung von Gemma vorbereiten.
PyTorch-Laufumgebung vorbereiten
Bereiten Sie die PyTorch-Ausführungsumgebung für das Modell vor, indem Sie das Gemma PyTorch-Repository klonen.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'... remote: Enumerating objects: 239, done. remote: Counting objects: 100% (123/123), done. remote: Compressing objects: 100% (68/68), done. remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116 Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done. Resolving deltas: 100% (135/135), done.
import sys
sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM
import os
import torch
Modellkonfiguration festlegen
Bevor Sie das Modell ausführen, müssen Sie einige Konfigurationsparameter festlegen, darunter die Gemma-Variante, den Tokenisierer und die Quantisierungsebene.
# Set up model config.
model_config = get_model_config(VARIANT)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path
Gerätekontext konfigurieren
Im folgenden Code wird der Gerätekontext für die Ausführung des Modells konfiguriert:
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
Modell instanziieren und laden
Laden Sie das Modell mit seinen Gewichten, um Anfragen auszuführen.
device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
model = Gemma3ForMultimodalLM(model_config)
model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
model = model.to(device).eval()
print("Model loading done.")
print('Generating requests in chat mode...')
Inferenz ausführen
Unten finden Sie Beispiele für die Generierung im Chatmodus und mit mehreren Anfragen.
Die Gemma-Modelle mit Anleitungsoptimierung wurden mit einem speziellen Formatierer trainiert, der Beispiele für die Anleitungsoptimierung sowohl während des Trainings als auch der Inferenz mit zusätzlichen Informationen annotiert. Die Anmerkungen (1) geben die Rollen in einer Unterhaltung an und (2) grenzen die Gesprächsrunden in einer Unterhaltung ab.
Die relevanten Anmerkungstokens sind:
user
: Nutzer ist an der Reihemodel
: Modelldrehung<start_of_turn>
: Beginn des Dialogschritts<start_of_image>
: Tag für die Eingabe von Bilddaten<end_of_turn><eos>
: Ende des Dialogschritts
Weitere Informationen zur Prompt-Formatierung für Gemma-Modelle, die anhand von Anweisungen optimiert wurden, finden Sie [hier](https://ai.google.dev/gemma/core/prompt-structure).
Text mit Text generieren
Im folgenden Code-Snippet wird gezeigt, wie ein Prompt für ein anhand von Anleitungen optimiertes Gemma-Modell mithilfe von Nutzer- und Modell-Chatvorlagen in einer Unterhaltung mit mehreren Gesprächsrunden formatiert wird.
# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"
# Sample formatted prompt
prompt = (
USER_CHAT_TEMPLATE.format(
prompt='What is a good place for travel in the US?'
)
+ MODEL_CHAT_TEMPLATE.format(prompt='California.')
+ USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
+ '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)
model.generate(
USER_CHAT_TEMPLATE.format(prompt=prompt),
device=device,
output_len=256,
)
Chat prompt: <start_of_turn>user What is a good place for travel in the US?<end_of_turn><eos> <start_of_turn>model California.<end_of_turn><eos> <start_of_turn>user What can I do in California?<end_of_turn><eos> <start_of_turn>model "California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo) \n\nThe more you tell me, the better recommendations I can give! 😊 \n<end_of_turn>"
# Generate sample
model.generate(
'Write a poem about an llm writing a poem.',
device=device,
output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"
Text mit Bildern generieren
Ab Gemma-Version 3 können Sie Bilder mit Ihrem Prompt verwenden. Im folgenden Beispiel wird gezeigt, wie Sie visuelle Daten in Ihren Prompt einfügen.
print('Chat with images...\n')
def read_image(url):
import io
import requests
import PIL
contents = io.BytesIO(requests.get(url).content)
return PIL.Image.open(contents)
image_url = 'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
image = read_image(image_url)
print(model.generate(
[['<start_of_turn>user\n',image, 'What animal is in this image?<end_of_turn>\n', '<start_of_turn>model\n']],
device=device,
output_len=OUTPUT_LEN,
))
Weitere Informationen
Nachdem Sie nun gelernt haben, wie Sie Gemma in PyTorch verwenden, können Sie sich unter ai.google.dev/gemma über die vielen anderen Möglichkeiten informieren, die Gemma bietet. Weitere Informationen finden Sie hier: