Inferenz mit RecurrentGemma unter Verwendung von JAX und Flax

Auf ai.google.dev ansehen In Google Colab ausführen In Vertex AI öffnen Quelle auf GitHub ansehen

In dieser Anleitung erfahren Sie, wie Sie mit dem RecurrentGemma 2B Instruct-Modell grundlegende Stichproben/Inferenzen mit der recurrentgemma-Bibliothek von Google DeepMind durchführen, die mit JAX (einer leistungsstarken numerischen Rechenbibliothek), Flax (der JAX-basierten neuronalen Netzwerkbibliothek), Orbax (einer JAX-basierten Bibliothek für Trainingsprogramme) und {1dePieSence-Token-Token1}-Bibliothek (Checkpointing-Bibliothek) geschrieben wurde.SentencePiece Obwohl Flax in diesem Notizbuch nicht direkt verwendet wird, wurde Flax verwendet, um Gemma und RecurrentGemma (das Griffin-Modell) zu erstellen.

Dieses Notebook kann in Google Colab mit der T4-GPU ausgeführt werden. Rufen Sie dazu Bearbeiten > Notebook-Einstellungen > Hardwarebeschleuniger auf und wählen Sie T4-GPU aus.

Einrichtung

In den folgenden Abschnitten werden die Schritte zum Vorbereiten eines Notebooks für die Verwendung eines RecurrentGemma-Modells erläutert, einschließlich des Modellzugriffs, des Abrufens eines API-Schlüssels und des Konfigurierens der Notebook-Laufzeit

Kaggle-Zugriff für Gemma einrichten

Um diese Anleitung abzuschließen, müssen Sie zuerst der Einrichtungsanleitung folgen, die der Gemma-Einrichtung ähnlich ist. Es gibt allerdings einige Ausnahmen:

  • Auf kaggle.com erhältst du Zugriff auf RecurrentGemma (anstelle von Gemma).
  • Wählen Sie eine Colab-Laufzeit mit ausreichenden Ressourcen zum Ausführen des RecurrentGemma-Modells aus.
  • Generieren und konfigurieren Sie einen Kaggle-Nutzernamen und einen API-Schlüssel.

Nachdem Sie die Einrichtung von RecurrentGemma abgeschlossen haben, fahren Sie mit dem nächsten Abschnitt fort. Dort legen Sie Umgebungsvariablen für Ihre Colab-Umgebung fest.

Umgebungsvariablen festlegen

Legen Sie Umgebungsvariablen für KAGGLE_USERNAME und KAGGLE_KEY fest. Wenn die Aufforderung „Zugriff erlauben?“ angezeigt wird, -Nachrichten, erklären Sie sich damit einverstanden, Secret-Zugriff bereitzustellen.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

recurrentgemma-Bibliothek installieren

In diesem Notebook wird eine kostenlose Colab-GPU verwendet. Klicken Sie zum Aktivieren der Hardwarebeschleunigung auf Bearbeiten > Notebook-Einstellungen > Wählen Sie T4 GPU aus > Klicken Sie auf Speichern.

Als Nächstes müssen Sie die recurrentgemma-Bibliothek von Google DeepMind von github.com/google-deepmind/recurrentgemma installieren. Wenn Sie einen Fehler zum „Abhängigkeitsauflöser von pip“ erhalten, können Sie ihn in der Regel ignorieren.

pip install git+https://github.com/google-deepmind/recurrentgemma.git

RecurrentGemma-Modell laden und vorbereiten

  1. Laden Sie das RecurrentGemma-Modell mit kagglehub.model_download. Dafür werden drei Argumente benötigt:
  • handle: Das Modell-Handle von Kaggle
  • path: (optionaler String) der lokale Pfad
  • force_download: (optionaler boolescher Wert) Erzwingt das erneute Herunterladen des Modells
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. Überprüfen Sie den Speicherort der Modellgewichtungen und des Tokenizers und legen Sie dann die Pfadvariablen fest. Das Tokenizer-Verzeichnis befindet sich im Hauptverzeichnis, in das Sie das Modell heruntergeladen haben, und die Modellgewichtungen befinden sich in einem Unterverzeichnis. Beispiel:
  • Die Datei tokenizer.model befindet sich in /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1.
  • Der Modellprüfpunkt befindet sich in /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

Stichprobenerhebung/Inferenz durchführen

  1. Laden Sie den Prüfpunkt des RecurrentGemma-Modells mit der Methode recurrentgemma.jax.load_parameters. Das auf "single_device" gesetzte Argument sharding lädt alle Modellparameter auf ein einzelnes Gerät.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. Laden Sie den Modelltokenizer RecurrentGemma, der mit sentencepiece.SentencePieceProcessor erstellt wurde:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Verwenden Sie recurrentgemma.GriffinConfig.from_flax_params_or_variables, um automatisch die richtige Konfiguration aus dem RecurrentGemma-Modellprüfpunkt zu laden. Instanziieren Sie dann das Griffin-Modell mit recurrentgemma.jax.Griffin.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. Erstellen Sie eine sampler mit recurrentgemma.jax.Sampler zusätzlich zum Prüfpunkt bzw. die Gewichte des RecurrentGemma-Modells und dem Tokenizer:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. Schreiben Sie einen Prompt in prompt und führen Sie eine Inferenz durch. Sie können total_generation_steps optimieren (die Anzahl der Schritte, die beim Generieren einer Antwort ausgeführt werden; in diesem Beispiel wird 50 verwendet, um den Hostspeicher beizubehalten).
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

Weitere Informationen