Auf ai.google.dev ansehen | In Google Colab ausführen | In Vertex AI öffnen | Quelle auf GitHub ansehen |
In dieser Anleitung erfahren Sie, wie Sie mit dem RecurrentGemma 2B Instruct-Modell grundlegende Stichproben/Inferenzen mit der recurrentgemma
-Bibliothek von Google DeepMind durchführen, die mit JAX (einer leistungsstarken numerischen Rechenbibliothek), Flax (der JAX-basierten neuronalen Netzwerkbibliothek), Orbax (einer JAX-basierten Bibliothek für Trainingsprogramme) und {1dePieSence-Token-Token1}-Bibliothek (Checkpointing-Bibliothek) geschrieben wurde.SentencePiece Obwohl Flax in diesem Notizbuch nicht direkt verwendet wird, wurde Flax verwendet, um Gemma und RecurrentGemma (das Griffin-Modell) zu erstellen.
Dieses Notebook kann in Google Colab mit der T4-GPU ausgeführt werden. Rufen Sie dazu Bearbeiten > Notebook-Einstellungen > Hardwarebeschleuniger auf und wählen Sie T4-GPU aus.
Einrichtung
In den folgenden Abschnitten werden die Schritte zum Vorbereiten eines Notebooks für die Verwendung eines RecurrentGemma-Modells erläutert, einschließlich des Modellzugriffs, des Abrufens eines API-Schlüssels und des Konfigurierens der Notebook-Laufzeit
Kaggle-Zugriff für Gemma einrichten
Um diese Anleitung abzuschließen, müssen Sie zuerst der Einrichtungsanleitung folgen, die der Gemma-Einrichtung ähnlich ist. Es gibt allerdings einige Ausnahmen:
- Auf kaggle.com erhältst du Zugriff auf RecurrentGemma (anstelle von Gemma).
- Wählen Sie eine Colab-Laufzeit mit ausreichenden Ressourcen zum Ausführen des RecurrentGemma-Modells aus.
- Generieren und konfigurieren Sie einen Kaggle-Nutzernamen und einen API-Schlüssel.
Nachdem Sie die Einrichtung von RecurrentGemma abgeschlossen haben, fahren Sie mit dem nächsten Abschnitt fort. Dort legen Sie Umgebungsvariablen für Ihre Colab-Umgebung fest.
Umgebungsvariablen festlegen
Legen Sie Umgebungsvariablen für KAGGLE_USERNAME
und KAGGLE_KEY
fest. Wenn die Aufforderung „Zugriff erlauben?“ angezeigt wird, -Nachrichten, erklären Sie sich damit einverstanden, Secret-Zugriff bereitzustellen.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
recurrentgemma
-Bibliothek installieren
In diesem Notebook wird eine kostenlose Colab-GPU verwendet. Klicken Sie zum Aktivieren der Hardwarebeschleunigung auf Bearbeiten > Notebook-Einstellungen > Wählen Sie T4 GPU aus > Klicken Sie auf Speichern.
Als Nächstes müssen Sie die recurrentgemma
-Bibliothek von Google DeepMind von github.com/google-deepmind/recurrentgemma
installieren. Wenn Sie einen Fehler zum „Abhängigkeitsauflöser von pip“ erhalten, können Sie ihn in der Regel ignorieren.
pip install git+https://github.com/google-deepmind/recurrentgemma.git
RecurrentGemma-Modell laden und vorbereiten
- Laden Sie das RecurrentGemma-Modell mit
kagglehub.model_download
. Dafür werden drei Argumente benötigt:
handle
: Das Modell-Handle von Kagglepath
: (optionaler String) der lokale Pfadforce_download
: (optionaler boolescher Wert) Erzwingt das erneute Herunterladen des Modells
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
- Überprüfen Sie den Speicherort der Modellgewichtungen und des Tokenizers und legen Sie dann die Pfadvariablen fest. Das Tokenizer-Verzeichnis befindet sich im Hauptverzeichnis, in das Sie das Modell heruntergeladen haben, und die Modellgewichtungen befinden sich in einem Unterverzeichnis. Beispiel:
- Die Datei
tokenizer.model
befindet sich in/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
. - Der Modellprüfpunkt befindet sich in
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
Stichprobenerhebung/Inferenz durchführen
- Laden Sie den Prüfpunkt des RecurrentGemma-Modells mit der Methode
recurrentgemma.jax.load_parameters
. Das auf"single_device"
gesetzte Argumentsharding
lädt alle Modellparameter auf ein einzelnes Gerät.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
- Laden Sie den Modelltokenizer RecurrentGemma, der mit
sentencepiece.SentencePieceProcessor
erstellt wurde:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- Verwenden Sie
recurrentgemma.GriffinConfig.from_flax_params_or_variables
, um automatisch die richtige Konfiguration aus dem RecurrentGemma-Modellprüfpunkt zu laden. Instanziieren Sie dann das Griffin-Modell mitrecurrentgemma.jax.Griffin
.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
- Erstellen Sie eine
sampler
mitrecurrentgemma.jax.Sampler
zusätzlich zum Prüfpunkt bzw. die Gewichte des RecurrentGemma-Modells und dem Tokenizer:
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
- Schreiben Sie einen Prompt in
prompt
und führen Sie eine Inferenz durch. Sie könnentotal_generation_steps
optimieren (die Anzahl der Schritte, die beim Generieren einer Antwort ausgeführt werden; in diesem Beispiel wird50
verwendet, um den Hostspeicher beizubehalten).
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Prompt: # 5+9=? Output: # Answer: 14 # Explanation: 5 + 9 = 14.
Weitere Informationen
- Weitere Informationen zur
recurrentgemma
-Bibliothek von Google DeepMind auf GitHub, die docstrings zu Methoden und Modulen enthält, die Sie in dieser Anleitung verwendet haben, z. B.recurrentgemma.jax.load_parameters
,recurrentgemma.jax.Griffin
undrecurrentgemma.jax.Sampler
. - Die folgenden Bibliotheken haben eigene Dokumentationswebsites: Core JAX, Flax und Orbax.
- Eine Dokumentation zu
sentencepiece
Tokenizer/Detokenizer finden Sie im GitHub-Repository von Googlesentencepiece
. - Die
kagglehub
-Dokumentation finden Sie unterREADME.md
imkagglehub
-GitHub-Repository von Kaggle. - Gemma-Modelle mit Google Cloud Vertex AI verwenden
- Sieh dir RecurrentGemma: Moving Past Transformers“ an for Effiziente Open Language Models von Google DeepMind.
- Lesen Sie den Artikel Griffin: Mixing Gated Linear Recurrences with Local Attention for Effective Language Models von GoogleDeepMind, um mehr über die von RecurrentGemma verwendete Modellarchitektur zu erfahren.