Gemma in PyTorch

Auf ai.google.dev ansehen In Google Colab ausführen Quelle auf GitHub ansehen

In dieser kurzen Demo wird die Gemma-Inferenz in PyTorch ausgeführt. Weitere Informationen finden Sie im GitHub-Repository der offiziellen PyTorch-Implementierung.

Hinweis:

  • Die kostenlose Python-Laufzeit für Colab-CPUs und T4-GPUs reicht aus, um die Gemma 2B-Modelle und die quantisierten 7B-Modelle mit int8 auszuführen.
  • Informationen zu erweiterten Anwendungsfällen für andere GPUs oder TPUs finden Sie in der README.md im offiziellen Repository.

1. Kaggle-Zugriff für Gemma einrichten

Um diese Anleitung abzuschließen, müssen Sie zuerst der Anleitung unter Gemma-Einrichtung folgen. Sie erfahren, wie Sie Folgendes tun:

  • Sie können Gemma unter kaggle.com nutzen.
  • Wählen Sie eine Colab-Laufzeit mit ausreichenden Ressourcen zum Ausführen des Gemma-Modells aus.
  • Erstellen und konfigurieren Sie einen Kaggle-Nutzernamen und einen API-Schlüssel.

Nachdem Sie die Gemma-Einrichtung abgeschlossen haben, fahren Sie mit dem nächsten Abschnitt fort, in dem Sie Umgebungsvariablen für Ihre Colab-Umgebung festlegen.

2. Umgebungsvariablen festlegen

Legen Sie Umgebungsvariablen für KAGGLE_USERNAME und KAGGLE_KEY fest. Wenn Sie die Aufforderung „Zugriff gewähren?“ sehen, stimmen Sie zu, den geheimen Zugriff zu gewähren.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Abhängigkeiten installieren

pip install -q -U torch immutabledict sentencepiece

Modellgewichte herunterladen

# Choose variant and machine type
VARIANT = '2b-it'
MACHINE_TYPE = 'cuda'

CONFIG = VARIANT[:2]
if CONFIG == '2b':
  CONFIG = '2b-v2'
import os
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-2/pyTorch/gemma-2-{VARIANT}')
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Modellimplementierung herunterladen

# NOTE: The "installation" is just cloning the repo.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch')
from gemma.config import GemmaConfig, get_model_config
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

Modell einrichten

# Set up model config.
model_config = get_model_config(CONFIG)
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

Inferenz ausführen

Unten findest du Beispiele für das Generieren im Chatmodus und das Generieren mit mehreren Anfragen.

Die mithilfe der Anleitung optimierten Gemma-Modelle wurden mit einem speziellen Formatierer trainiert, der sowohl während des Trainings als auch der Inferenz Beispiele für die Anleitungsoptimierung mit zusätzlichen Informationen annotiert. Die Anmerkungen (1) geben die Rollen in einer Unterhaltung an und (2) grenzen die Gesprächsrunden in einer Unterhaltung ab.

Die relevanten Anmerkungstokens sind:

  • user: Nutzer ist an der Reihe
  • model: Modelldrehung
  • <start_of_turn>: Beginn eines Dialogschritts
  • <end_of_turn><eos>: Ende des Dialogzugs

Weitere Informationen zur Prompt-Formatierung für Gemma-Modelle, die anhand von Anweisungen optimiert wurden, finden Sie hier.

Im folgenden Code-Snippet wird gezeigt, wie ein Prompt für ein anhand von Anleitungen optimiertes Gemma-Modell mithilfe von Nutzer- und Modell-Chatvorlagen in einer Unterhaltung mit mehreren Antworten formatiert wird.

# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=128,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

Weitere Informationen

Nachdem Sie nun gelernt haben, wie Sie Gemma in PyTorch verwenden, können Sie sich unter ai.google.dev/gemma über die vielen anderen Möglichkeiten informieren, die Gemma bietet. Weitere Informationen finden Sie hier: