Inferenz mit RecurrentGemma

Auf ai.google.dev ansehen In Google Colab ausführen In Vertex AI öffnen Quelle auf GitHub ansehen

In dieser Anleitung erfahren Sie, wie Sie mit dem RecurrentGemma 2B Instruct-Modell grundlegende Stichproben/Inferenzen mit der recurrentgemma-Bibliothek von Google DeepMind durchführen, die mit JAX (einer leistungsstarken numerischen Rechenbibliothek), Flax (der JAX-basierten neuronalen Netzwerkbibliothek), Orbax (einer JAX-basierten Bibliothek für Trainingsprogramme) und {1dePieSence-Token-Token1}-Bibliothek (Checkpointing-Bibliothek) geschrieben wurde.SentencePiece Obwohl Flax in diesem Notizbuch nicht direkt verwendet wird, wurde Flax verwendet, um Gemma und RecurrentGemma (das Griffin-Modell) zu erstellen.

Dieses Notebook kann in Google Colab mit der T4-GPU ausgeführt werden. Rufen Sie dazu Bearbeiten > Notebook-Einstellungen > Hardwarebeschleuniger auf und wählen Sie T4-GPU aus.

Einrichtung

In den folgenden Abschnitten werden die Schritte zum Vorbereiten eines Notebooks für die Verwendung eines RecurrentGemma-Modells erläutert, einschließlich des Modellzugriffs, des Abrufens eines API-Schlüssels und des Konfigurierens der Notebook-Laufzeit

Kaggle-Zugriff für Gemma einrichten

Um diese Anleitung abzuschließen, müssen Sie zuerst der Einrichtungsanleitung folgen, die der Gemma-Einrichtung ähnlich ist. Es gibt allerdings einige Ausnahmen:

  • Auf kaggle.com erhältst du Zugriff auf RecurrentGemma (anstelle von Gemma).
  • Wählen Sie eine Colab-Laufzeit mit ausreichenden Ressourcen zum Ausführen des RecurrentGemma-Modells aus.
  • Generieren und konfigurieren Sie einen Kaggle-Nutzernamen und einen API-Schlüssel.

Nachdem Sie die Einrichtung von RecurrentGemma abgeschlossen haben, fahren Sie mit dem nächsten Abschnitt fort. Dort legen Sie Umgebungsvariablen für Ihre Colab-Umgebung fest.

Umgebungsvariablen festlegen

Legen Sie Umgebungsvariablen für KAGGLE_USERNAME und KAGGLE_KEY fest. Wenn die Meldung „Zugriff gewähren?“ angezeigt wird, stimmen Sie zu, Secret-Zugriff zu gewähren.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

recurrentgemma-Bibliothek installieren

In diesem Notebook wird eine kostenlose Colab-GPU verwendet. Klicken Sie auf Bearbeiten > Notebook-Einstellungen > T4-GPU > Speichern, um die Hardwarebeschleunigung zu aktivieren.

Als Nächstes müssen Sie die recurrentgemma-Bibliothek von Google DeepMind von github.com/google-deepmind/recurrentgemma installieren. Wenn Sie eine Fehlermeldung zum „Abhängigkeitsauflöser von pip“ erhalten, können Sie ihn in der Regel ignorieren.

pip install git+https://github.com/google-deepmind/recurrentgemma.git

RecurrentGemma-Modell laden und vorbereiten

  1. Laden Sie das RecurrentGemma-Modell mit kagglehub.model_download. Dafür werden drei Argumente benötigt:
  • handle: Das Modell-Handle von Kaggle
  • path: (optionaler String) der lokale Pfad
  • force_download: (optionaler boolescher Wert) Erzwingt das erneute Herunterladen des Modells
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. Prüfen Sie den Speicherort der Modellgewichtungen und des Tokenizers und legen Sie dann die Pfadvariablen fest. Das Tokenizer-Verzeichnis befindet sich im Hauptverzeichnis, in das Sie das Modell heruntergeladen haben, und die Modellgewichtungen befinden sich in einem Unterverzeichnis. Beispiel:
  • Die Datei tokenizer.model befindet sich in /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1.
  • Der Modellprüfpunkt befindet sich in /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it.
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

Stichprobenerhebung/Inferenz durchführen

  1. Laden Sie den Prüfpunkt des RecurrentGemma-Modells mit der Methode recurrentgemma.jax.load_parameters. Das auf "single_device" gesetzte Argument sharding lädt alle Modellparameter auf ein einzelnes Gerät.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. Laden Sie den Modelltokenizer RecurrentGemma, der mit sentencepiece.SentencePieceProcessor erstellt wurde:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Verwenden Sie recurrentgemma.GriffinConfig.from_flax_params_or_variables, um automatisch die richtige Konfiguration aus dem RecurrentGemma-Modellprüfpunkt zu laden. Instanziieren Sie dann das Griffin-Modell mit recurrentgemma.jax.Griffin.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. Erstellen Sie eine sampler mit recurrentgemma.jax.Sampler zusätzlich zum Prüfpunkt bzw. die Gewichte des RecurrentGemma-Modells und dem Tokenizer:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. Schreiben Sie einen Prompt in prompt und führen Sie eine Inferenz durch. Sie können total_generation_steps optimieren (die Anzahl der Schritte, die beim Generieren einer Antwort ausgeführt werden; in diesem Beispiel wird 50 verwendet, um den Hostspeicher beizubehalten).
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

Weitere Informationen