Inferenz mit Gemma mithilfe von JAX und Flax

Auf ai.google.dev ansehen In Google Colab ausführen In Vertex AI öffnen Quelle auf GitHub ansehen

Überblick

Gemma ist eine Familie von schlanken, modernen offenen Large Language Models, die auf der Forschung und Technologie von Google DeepMind Gemini basieren. In dieser Anleitung wird gezeigt, wie Sie grundlegende Stichproben/Inferenzen mit dem Gemma 2B-InSTRUCT-Modell mithilfe der gemma-Bibliothek von Google DeepMind durchführen, die mit JAX (einer leistungsstarken numerischen Computing-Bibliothek), Flax (die JAX-basierte neuronale Netzwerkbibliothek), Orbax (eine JAX-basierte Bibliothek für Trainingsdienstprogramme wie Checkpointing) und SentencePiece Flax wird in diesem Notizbuch nicht direkt verwendet, aber aus Flax wurde Gemma hergestellt.

Dieses Notebook kann in Google Colab mit einer kostenlosen T4-GPU ausgeführt werden. Gehen Sie dazu zu Bearbeiten > Notebook-Einstellungen > wählen Sie unter Hardwarebeschleuniger die Option T4-GPU aus.

Einrichtung

1. Kaggle-Zugriff für Gemma einrichten

Um diese Anleitung abzuschließen, folgen Sie zuerst der Einrichtungsanleitung unter Gemma-Einrichtung, die Ihnen folgende Schritte zeigt:

  • Zugriff auf Gemma erhalten Sie unter kaggle.com.
  • Wählen Sie eine Colab-Laufzeit mit ausreichend Ressourcen zum Ausführen des Gemma-Modells aus.
  • Generieren und konfigurieren Sie einen Kaggle-Nutzernamen und einen API-Schlüssel.

Fahren Sie nach Abschluss der Gmma-Einrichtung mit dem nächsten Abschnitt fort, in dem Sie Umgebungsvariablen für Ihre Colab-Umgebung festlegen.

2. Umgebungsvariablen festlegen

Legen Sie Umgebungsvariablen für KAGGLE_USERNAME und KAGGLE_KEY fest. Wenn die Meldung „Zugriff gewähren?“ angezeigt wird, stimmen Sie der Bereitstellung des Secret-Zugriffs zu.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. gemma-Bibliothek installieren

Dieses Notebook konzentriert sich auf die Verwendung einer kostenlosen Colab-GPU. Klicken Sie zum Aktivieren der Hardwarebeschleunigung auf Bearbeiten > Notebook-Einstellungen > T4-GPU > Speichern.

Als Nächstes müssen Sie die gemma-Bibliothek von Google DeepMind von github.com/google-deepmind/gemma installieren. Wenn Sie eine Fehlermeldung zum „Abhängigkeitsauflöser von pip“ erhalten, können Sie diese in der Regel ignorieren.

pip install -q git+https://github.com/google-deepmind/gemma.git

Gemma-Modell laden und vorbereiten

  1. Laden Sie das Gemma-Modell mit kagglehub.model_download, das drei Argumente verwendet:
  • handle: der Modell-Handle von Kaggle
  • path: (optionaler String) Der lokale Pfad
  • force_download: (optionaler boolescher Wert) Erzwingt einen erneuten Download des Modells
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:35<00:00, 110MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
  1. Prüfen Sie den Speicherort der Modellgewichtungen und des Tokenizers und legen Sie dann die Pfadvariablen fest. Das Tokenizer-Verzeichnis befindet sich im Hauptverzeichnis, in das Sie das Modell heruntergeladen haben, während sich die Modellgewichtungen in einem Unterverzeichnis befinden. Beispiel:
  • Die Datei „tokenizer.model“ wird das Format /LOCAL/PATH/TO/gemma/flax/2b-it/2 haben.
  • Der Modell-Checkpoint befindet sich in /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model

Stichproben/Inferenz durchführen

  1. Laden und formatieren Sie den Prüfpunkt für das Gemma-Modell mit der Methode gemma.params.load_and_format_params:
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
  1. Laden Sie den Gemma-Tokenizer, der mit sentencepiece.SentencePieceProcessor erstellt wurde:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Verwenden Sie gemma.transformer.TransformerConfig, um automatisch die richtige Konfiguration aus dem Gemma-Modellprüfpunkt zu laden. Das Argument cache_size ist die Anzahl der Zeitschritte im Gemma-Transformer-Cache. Instanziieren Sie anschließend das Gemma-Modell als transformer mit gemma.transformer.Transformer (übernimmt von flax.linen.Module).
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
  1. Erstellen Sie eine sampler mit gemma.sampler.Sampler zusätzlich zum Checkpoint/Gewichtung des Gemma-Modells und des Tokenizers:
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)
  1. Schreiben Sie einen Prompt in input_batch und führen Sie eine Inferenz durch. Sie können total_generation_steps optimieren, d. h. die Anzahl der Schritte, die beim Generieren einer Antwort ausgeführt werden. In diesem Beispiel wird 100 verwendet, um den Hostarbeitsspeicher zu erhalten.
prompt = [
    "\n# What is the meaning of life?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=100,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt:

# What is the meaning of life?
Output:


The question of what the meaning of life is one that has occupied the minds of philosophers, theologians, and individuals for centuries. There is no single, universally accepted answer, but there are many different perspectives on this complex and multifaceted question.

**Some common perspectives on the meaning of life include:**

* **Biological perspective:** From a biological standpoint, the meaning of life is to survive and reproduce.
* **Existential perspective:** Existentialists believe that life is not inherently meaningful and that
  1. Optional: Führen Sie diese Zelle aus, um Arbeitsspeicher freizugeben, wenn Sie das Notebook fertiggestellt haben und eine weitere Eingabeaufforderung ausprobieren möchten. Danach können Sie die sampler in Schritt 3 noch einmal instanziieren und die Eingabeaufforderung in Schritt 4 anpassen und ausführen.
del sampler

Weitere Informationen