|
|
In Google Colab ausführen
|
|
Quelle auf GitHub ansehen
|
In dieser Anleitung erfahren Sie, wie Sie mit dem RecurrentGemma 2B Instruct-Modell grundlegende Stichproben/Inferenzen mit der recurrentgemma-Bibliothek von Google DeepMind durchführen, die mit JAX (einer leistungsstarken numerischen Rechenbibliothek), Flax (der JAX-basierten neuronalen Netzwerkbibliothek), Orbax (einer JAX-basierten Bibliothek für Trainingsprogramme) und {1dePieSence-Token-Token1}-Bibliothek (Checkpointing-Bibliothek) geschrieben wurde.SentencePiece Obwohl Flax in diesem Notizbuch nicht direkt verwendet wird, wurde Flax verwendet, um Gemma und RecurrentGemma (das Griffin-Modell) zu erstellen.
Dieses Notebook kann in Google Colab mit der T4-GPU ausgeführt werden. Rufen Sie dazu Bearbeiten > Notebook-Einstellungen > Hardwarebeschleuniger auf und wählen Sie T4-GPU aus.
Einrichtung
In den folgenden Abschnitten werden die Schritte zum Vorbereiten eines Notebooks für die Verwendung eines RecurrentGemma-Modells erläutert, einschließlich des Modellzugriffs, des Abrufens eines API-Schlüssels und des Konfigurierens der Notebook-Laufzeit
Kaggle-Zugriff für Gemma einrichten
Um diese Anleitung abzuschließen, müssen Sie zuerst der Einrichtungsanleitung folgen, die der Gemma-Einrichtung ähnlich ist. Es gibt allerdings einige Ausnahmen:
- Auf kaggle.com erhältst du Zugriff auf RecurrentGemma (anstelle von Gemma).
- Wählen Sie eine Colab-Laufzeit mit ausreichenden Ressourcen zum Ausführen des RecurrentGemma-Modells aus.
- Generieren und konfigurieren Sie einen Kaggle-Nutzernamen und einen API-Schlüssel.
Nachdem Sie die Einrichtung von RecurrentGemma abgeschlossen haben, fahren Sie mit dem nächsten Abschnitt fort. Dort legen Sie Umgebungsvariablen für Ihre Colab-Umgebung fest.
Umgebungsvariablen festlegen
Legen Sie Umgebungsvariablen für KAGGLE_USERNAME und KAGGLE_KEY fest. Wenn die Aufforderung „Zugriff erlauben?“ angezeigt wird, -Nachrichten, erklären Sie sich damit einverstanden, Secret-Zugriff bereitzustellen.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
recurrentgemma-Bibliothek installieren
In diesem Notebook wird eine kostenlose Colab-GPU verwendet. Klicken Sie zum Aktivieren der Hardwarebeschleunigung auf Bearbeiten > Notebook-Einstellungen > Wählen Sie T4 GPU aus > Klicken Sie auf Speichern.
Als Nächstes müssen Sie die recurrentgemma-Bibliothek von Google DeepMind von github.com/google-deepmind/recurrentgemma installieren. Wenn Sie einen Fehler zum „Abhängigkeitsauflöser von pip“ erhalten, können Sie ihn in der Regel ignorieren.
pip install git+https://github.com/google-deepmind/recurrentgemma.git
RecurrentGemma-Modell laden und vorbereiten
- Laden Sie das RecurrentGemma-Modell mit
kagglehub.model_download. Dafür werden drei Argumente benötigt:
handle: Das Modell-Handle von Kagglepath: (optionaler String) der lokale Pfadforce_download: (optionaler boolescher Wert) Erzwingt das erneute Herunterladen des Modells
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
- Überprüfen Sie den Speicherort der Modellgewichtungen und des Tokenizers und legen Sie dann die Pfadvariablen fest. Das Tokenizer-Verzeichnis befindet sich im Hauptverzeichnis, in das Sie das Modell heruntergeladen haben, und die Modellgewichtungen befinden sich in einem Unterverzeichnis. Beispiel:
- Die Datei
tokenizer.modelbefindet sich in/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1. - Der Modellprüfpunkt befindet sich in
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
Stichprobenerhebung/Inferenz durchführen
- Laden Sie den Prüfpunkt des RecurrentGemma-Modells mit der Methode
recurrentgemma.jax.load_parameters. Das auf"single_device"gesetzte Argumentshardinglädt alle Modellparameter auf ein einzelnes Gerät.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
- Laden Sie den Modelltokenizer RecurrentGemma, der mit
sentencepiece.SentencePieceProcessorerstellt wurde:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- Verwenden Sie
recurrentgemma.GriffinConfig.from_flax_params_or_variables, um automatisch die richtige Konfiguration aus dem RecurrentGemma-Modellprüfpunkt zu laden. Instanziieren Sie dann das Griffin-Modell mitrecurrentgemma.jax.Griffin.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
- Erstellen Sie eine
samplermitrecurrentgemma.jax.Samplerzusätzlich zum Prüfpunkt bzw. die Gewichte des RecurrentGemma-Modells und dem Tokenizer:
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
- Schreiben Sie einen Prompt in
promptund führen Sie eine Inferenz durch. Sie könnentotal_generation_stepsoptimieren (die Anzahl der Schritte, die beim Generieren einer Antwort ausgeführt werden; in diesem Beispiel wird50verwendet, um den Hostspeicher beizubehalten).
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
warnings.warn("Some donated buffers were not usable:"
Prompt:
# 5+9=?
Output:
# Answer: 14
# Explanation: 5 + 9 = 14.
Weitere Informationen
- Weitere Informationen zur
recurrentgemma-Bibliothek von Google DeepMind auf GitHub, die docstrings zu Methoden und Modulen enthält, die Sie in dieser Anleitung verwendet haben, z. B.recurrentgemma.jax.load_parameters,recurrentgemma.jax.Griffinundrecurrentgemma.jax.Sampler. - Die folgenden Bibliotheken haben eigene Dokumentationswebsites: Core JAX, Flax und Orbax.
- Eine Dokumentation zu
sentencepieceTokenizer/Detokenizer finden Sie im GitHub-Repository von Googlesentencepiece. - Die
kagglehub-Dokumentation finden Sie unterREADME.mdimkagglehub-GitHub-Repository von Kaggle. - Gemma-Modelle mit Google Cloud Vertex AI verwenden
- Sieh dir RecurrentGemma: Moving Past Transformers“ an for Effiziente Open Language Models von Google DeepMind.
- Lesen Sie den Artikel Griffin: Mixing Gated Linear Recurrences with Local Attention for Effective Language Models von GoogleDeepMind, um mehr über die von RecurrentGemma verwendete Modellarchitektur zu erfahren.
In Google Colab ausführen
Quelle auf GitHub ansehen