Inferenzen mit Gemma und Keras ausführen

Auf ai.google.dev ansehen In Google Colab ausführen In Vertex AI öffnen Quelle auf GitHub ansehen

In dieser Anleitung erfahren Sie, wie Sie Gemma mit KerasNLP verwenden, um Inferenzen auszuführen und Text zu generieren. Gemma ist eine Familie von leichten, hochmodernen offenen Modellen, die auf derselben Forschung und Technologie basieren, die auch für die Erstellung der Gemini-Modelle verwendet wurden. KerasNLP ist eine Sammlung von NLP-Modellen (Natural Language Processing), die in Keras implementiert und auf JAX, PyTorch und TensorFlow ausführbar sind.

In dieser Anleitung verwenden Sie Gemma, um Textantworten auf mehrere Prompts zu generieren. Wenn Sie Keras zum ersten Mal verwenden, sollten Sie den Artikel Erste Schritte mit Keras lesen. Dies ist jedoch nicht zwingend erforderlich. Im Verlauf dieses Tutorials erfahren Sie mehr über Keras.

Einrichtung

Gemma-Einrichtung

Um diese Anleitung abzuschließen, müssen Sie zuerst die Schritte unter Gemma-Einrichtung ausführen. In der Anleitung zur Einrichtung von Gemma erfahren Sie, wie Sie Folgendes tun können:

  • Auf kaggle.com erhältst du Zugriff auf Gemma.
  • Wählen Sie eine Colab-Laufzeit mit ausreichenden Ressourcen zum Ausführen des Gemma 2B-Modells aus.
  • Generieren und konfigurieren Sie einen Kaggle-Nutzernamen und einen API-Schlüssel.

Nachdem Sie die Gemma-Einrichtung abgeschlossen haben, fahren Sie mit dem nächsten Abschnitt fort. Dort legen Sie Umgebungsvariablen für Ihre Colab-Umgebung fest.

Umgebungsvariablen festlegen

Legen Sie Umgebungsvariablen für KAGGLE_USERNAME und KAGGLE_KEY fest.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Abhängigkeiten installieren

Installieren Sie Keras und KerasNLP.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U keras>=3

Backend auswählen

Keras ist eine Deep-Learning-API auf hoher Ebene mit mehreren Frameworks, die auf einfache und nutzerfreundliche Weise entwickelt wurde. Mit Keras 3 können Sie das Back-End auswählen: TensorFlow, JAX oder PyTorch. Alle drei funktionieren in dieser Anleitung.

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Pakete importieren

Importieren Sie Keras und KerasNLP.

import keras
import keras_nlp

Modell erstellen

KerasNLP bietet Implementierungen vieler beliebter Modellarchitekturen. In dieser Anleitung erstellen Sie ein Modell mit GemmaCausalLM, einem End-to-End-Gemma-Modell für kausale Language Models. Ein kausales Sprachmodell sagt das nächste Token basierend auf vorherigen Tokens voraus.

Erstellen Sie das Modell mit der Methode from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...

Die Funktion GemmaCausalLM.from_preset() instanziiert das Modell aus einer voreingestellten Architektur und Gewichtungen. Im Code oben gibt der String "gemma_2b_en" die Voreinstellung für das Gemma 2B-Modell mit 2 Milliarden Parametern an. Gemma-Modelle mit den Parametern 7B, 9B und 27B sind ebenfalls verfügbar. Sie finden die Codestrings für Gemma-Modelle in der Liste der Modellvarianten auf kaggle.com.

Verwenden Sie summary, um weitere Informationen zum Modell zu erhalten:

gemma_lm.summary()

Wie Sie der Zusammenfassung entnehmen können, hat das Modell 2,5 Milliarden trainierbare Parameter.

Text generieren

Jetzt ist es an der Zeit, Text zu generieren. Das Modell hat eine generate-Methode, die Text basierend auf einem Prompt generiert. Das optionale Argument max_length gibt die maximale Länge der generierten Sequenz an.

Probieren Sie es mit dem Prompt "What is the meaning of life?" aus.

gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives'

Versuche, generate mit einem anderen Prompt noch einmal aufzurufen.

gemma_lm.generate("How does the brain work?", max_length=64)
'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up'

Wenn Sie JAX- oder TensorFlow-Back-Ends verwenden, werden Sie feststellen, dass der zweite generate-Aufruf fast sofort zurückgegeben wird. Das liegt daran, dass jeder Aufruf von generate für eine bestimmte Batchgröße und max_length mit XLA kompiliert wird. Der erste Durchlauf ist teuer, nachfolgende Durchläufe sind jedoch viel schneller.

Sie können auch stapelweise Prompts mithilfe einer Liste als Eingabe bereitstellen:

gemma_lm.generate(
    ["What is the meaning of life?",
     "How does the brain work?"],
    max_length=64)
['What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives',
 'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up']

Optional: Anderen Sampler ausprobieren

Sie können die Generierungsstrategie für GemmaCausalLM steuern, indem Sie das Argument sampler für compile() festlegen. Standardmäßig wird die Stichprobenerhebung "greedy" verwendet.

Legen Sie probeweise eine "top_k"-Strategie fest:

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life? That was a question I asked myself as I was driving home from work one night in 2012. I was driving through the city of San Bernardino, and all I could think was, “What the heck am I doing?”\n\nMy life was completely different. I'

Während der Standard-Gieredy-Algorithmus immer das Token mit der größten Wahrscheinlichkeit auswählt, wählt der Top-K-Algorithmus zufällig das nächste Token aus den Tokens mit der Top-K-Wahrscheinlichkeit aus.

Sie müssen keinen Sampler angeben und können das letzte Code-Snippet ignorieren, wenn es für Ihren Anwendungsfall nicht hilfreich ist. Weitere Informationen zu den verfügbaren Samplern finden Sie hier.

Nächste Schritte

In dieser Anleitung haben Sie gelernt, wie Sie mit KerasNLP und Gemma Text generieren. Hier ein paar Vorschläge, was Sie als Nächstes lernen können: