Erste Schritte mit Gemma und KerasNLP

Auf ai.google.dev ansehen In Google Colab ausführen In Vertex AI öffnen Quelle auf GitHub ansehen

In dieser Anleitung werden die ersten Schritte mit Gemma mithilfe von KerasNLP beschrieben. Gemma ist eine Familie leichter, hochmoderner offener Modelle, die auf derselben Forschung und Technologie basieren, die auch für die Gemini-Modelle verwendet werden. KerasNLP ist eine Sammlung von NLP-Modellen (Natural Language Processing), die in Keras implementiert und auf JAX, PyTorch und TensorFlow ausgeführt werden können.

In dieser Anleitung verwenden Sie Gemma, um Textantworten auf verschiedene Prompts zu generieren. Wenn Sie Keras noch nicht kennen, sollten Sie zuerst den Artikel Erste Schritte mit Keras lesen. Dies ist jedoch nicht zwingend erforderlich. Sie erfahren mehr über Keras, während Sie diese Anleitung durcharbeiten.

Einrichtung

Gemma-Einrichtung

Um diese Anleitung abzuschließen, müssen Sie zuerst die Einrichtungsanleitung unter Gemma-Einrichtung ausführen. In der Anleitung zur Einrichtung von Gemma finden Sie folgende Anleitungen:

  • Du erhältst Zugriff auf Gemma auf kaggle.com.
  • Wählen Sie eine Colab-Laufzeit mit genügend Ressourcen zum Ausführen des Gemma 2B-Modells aus.
  • Generieren und konfigurieren Sie einen Kaggle-Nutzernamen und einen API-Schlüssel.

Fahren Sie nach Abschluss der Gmma-Einrichtung mit dem nächsten Abschnitt fort, in dem Sie Umgebungsvariablen für Ihre Colab-Umgebung festlegen.

Umgebungsvariablen festlegen

Legen Sie Umgebungsvariablen für KAGGLE_USERNAME und KAGGLE_KEY fest.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Abhängigkeiten installieren

Installieren Sie Keras und KerasNLP.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U keras>=3

Backend auswählen

Keras ist eine High-Level-Deep-Learning-API mit mehreren Frameworks, die auf Einfachheit und Nutzerfreundlichkeit ausgelegt ist. Bei Keras 3 können Sie das Back-End auswählen: TensorFlow, JAX oder PyTorch. Für diese Anleitung sind alle drei geeignet.

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Pakete importieren

Importieren Sie Keras und KerasNLP.

import keras
import keras_nlp

Modell erstellen

KerasNLP bietet Implementierungen vieler beliebter Modellarchitekturen. In dieser Anleitung erstellen Sie ein Modell mit GemmaCausalLM, einem End-to-End-Gemma-Modell für die kausale Sprachmodellierung. Ein kausales Sprachmodell sagt das nächste Token basierend auf vorherigen Tokens voraus.

Erstellen Sie das Modell mit der Methode from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...

from_preset instanziiert das Modell aus einer voreingestellten Architektur und Gewichtungen. Im Code oben gibt der String "gemma_2b_en" die voreingestellte Architektur an: ein Gemma-Modell mit 2 Milliarden Parametern.

Verwenden Sie summary, um weitere Informationen zum Modell zu erhalten:

gemma_lm.summary()

Wie Sie der Zusammenfassung entnehmen können, hat das Modell 2,5 Milliarden trainierbare Parameter.

Text generieren

Jetzt ist es an der Zeit, Text zu generieren! Das Modell hat die Methode generate, die Text basierend auf einem Prompt generiert. Das optionale Argument max_length gibt die maximale Länge der generierten Sequenz an.

Probieren Sie es mit der Aufforderung "What is the meaning of life?" aus.

gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives'

Versuche, generate mit einer anderen Aufforderung noch einmal aufzurufen.

gemma_lm.generate("How does the brain work?", max_length=64)
'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up'

Wenn Sie JAX- oder TensorFlow-Back-Ends ausführen, werden Sie feststellen, dass der zweite generate-Aufruf nahezu sofort zurückgegeben wird. Das liegt daran, dass jeder Aufruf von generate für eine bestimmte Batchgröße und max_length mit XLA kompiliert wird. Der erste Durchlauf ist teuer, aber nachfolgende Durchläufe ist viel schneller.

Sie können auch aufeinanderfolgende Aufforderungen mithilfe einer Liste als Eingabe bereitstellen:

gemma_lm.generate(
    ["What is the meaning of life?",
     "How does the brain work?"],
    max_length=64)
['What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives',
 'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up']

Optional: Anderen Sampler ausprobieren

Sie können die Generierungsstrategie für GemmaCausalLM steuern, indem Sie das Argument sampler auf compile() festlegen. Standardmäßig wird eine Stichprobe von "greedy" verwendet.

Versuchen Sie, eine "top_k"-Strategie festzulegen:

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life? That was a question I asked myself as I was driving home from work one night in 2012. I was driving through the city of San Bernardino, and all I could think was, “What the heck am I doing?”\n\nMy life was completely different. I'

Während der Standard-Greedy-Algorithmus immer das Token mit der größten Wahrscheinlichkeit auswählt, wählt der Top-K-Algorithmus das nächste Token nach dem Zufallsprinzip aus den Tokens mit der höchsten K-Wahrscheinlichkeit aus.

Sie müssen keinen Sampler angeben und können das letzte Code-Snippet ignorieren, wenn es für Ihren Anwendungsfall nicht hilfreich ist. Weitere Informationen zu den verfügbaren Samplern finden Sie hier.

Nächste Schritte

In dieser Anleitung haben Sie gelernt, wie Sie mit KerasNLP und Gemma Text generieren. Hier sind ein paar Vorschläge, was Sie als Nächstes lernen können: