Verteilte Feinabstimmung mit Gemma unter Verwendung von Keras

Auf ai.google.dev ansehen In Google Colab ausführen In Kaggle ausführen In Vertex AI öffnen Quelle auf GitHub ansehen

Übersicht

Gemma ist eine Familie leichtgewichtiger, hochmoderner offener Modelle, die auf Forschung und Technologie basieren, um Google Gemini-Modelle zu erstellen. Gemma kann weiter auf spezifische Anforderungen abgestimmt werden. Large Language Models wie Gemma können jedoch sehr groß sein und einige davon passen möglicherweise nicht auf einen Sing-Beschleuniger für die Feinabstimmung. In diesem Fall gibt es zwei allgemeine Ansätze zur Optimierung:

  1. Parameter Effiziente Feinabstimmung (PEFT), die versucht, die effektive Modellgröße durch Abstriche bei der Genauigkeit zu verkleinern. LoRA fällt in diese Kategorie. In der Anleitung Gemma-Modelle in Keras mit LoRA optimieren wird gezeigt, wie das Gemma 2B-Modell gemma_2b_en mit LoRA mit KerasNLP auf einer einzigen GPU optimiert wird.
  2. Vollständige Parameteroptimierung mit Modellparallelität. Die Modellparallelität verteilt die Gewichtungen eines einzelnen Modells auf mehrere Geräte und ermöglicht die horizontale Skalierung. Weitere Informationen zum verteilten Training finden Sie in diesem Keras-Leitfaden.

In dieser Anleitung erfahren Sie, wie Sie Keras mit einem JAX-Backend verwenden, um das Gemma 7B-Modell mit LoRA und Modellparallismus auf der Tensor Processing Unit (TPU) von Google zu optimieren. Beachten Sie, dass LoRA in dieser Anleitung deaktiviert werden kann, um eine langsamere, aber präzisere Feinabstimmung der vollständigen Parameter zu erreichen.

Beschleuniger verwenden

Technisch gesehen können Sie für diese Anleitung entweder eine TPU oder eine GPU verwenden.

Hinweise zu TPU-Umgebungen

Google bietet drei Produkte für TPUs an:

  • Colab stellt TPU v2 kostenlos zur Verfügung. Das ist für diese Anleitung ausreichend.
  • Kaggle bietet TPU v3 kostenlos an, die ebenfalls für diese Anleitung geeignet sind.
  • Cloud TPU bietet TPU v3 und neuere Generationen. Eine Möglichkeit zum Einrichten:
    1. Neue TPU-VM erstellen
    2. Richten Sie die SSH-Portweiterleitung für den gewünschten Jupyter-Serverport ein
    3. Jupyter installieren und auf der TPU-VM starten und dann über „Mit lokaler Laufzeit verbinden“ eine Verbindung zu Colab herstellen

Hinweise zur Einrichtung mit mehreren GPUs

Obwohl sich diese Anleitung auf den TPU-Anwendungsfall konzentriert, können Sie ihn einfach an Ihre eigenen Anforderungen anpassen, wenn Sie einen Computer mit mehreren GPUs haben.

Wenn Sie lieber mit Colab arbeiten, können Sie auch direkt über „Verbindung zu einer benutzerdefinierten GCE-VM herstellen“ eine VM mit mehreren GPUs für Colab bereitstellen. im Colab Connect-Menü.

Wir konzentrieren uns hier auf die Verwendung der kostenlosen TPU von Kaggle.

Hinweis

Kaggle-Qualifikationen

Gemma-Modelle werden von Kaggle gehostet. Um Gemma zu verwenden, fordern Sie Zugriff auf Kaggle an:

  • Melde dich an oder registriere dich auf kaggle.com
  • Öffnen Sie die Gemma-Modellkarte und wählen Sie Zugriff anfordern aus.
  • Einwilligungsformular ausfüllen und Nutzungsbedingungen akzeptieren

Um die Kaggle API zu verwenden, erstellen Sie dann ein API-Token:

  • Öffnen Sie die Kaggle-Einstellungen.
  • Wählen Sie Create New Token (Neues Token erstellen) aus.
  • Eine kaggle.json-Datei wird heruntergeladen. Es enthält Ihre Kaggle-Anmeldedaten.

Führen Sie die folgende Zelle aus und geben Sie Ihre Kaggle-Anmeldedaten ein, wenn Sie dazu aufgefordert werden.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Alternativ können Sie KAGGLE_USERNAME und KAGGLE_KEY in Ihrer Umgebung festlegen, wenn kagglehub.login() nicht funktioniert.

Installation

Installieren Sie Keras und KerasNLP mit dem Gemma-Modell.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Keras-JAX-Backend einrichten

Importieren Sie JAX und führen Sie eine Plausibilitätsprüfung auf der TPU durch. Kaggle bietet TPUv3-8-Geräte mit 8 TPU-Kernen mit jeweils 16 GB Arbeitsspeicher.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Modell laden

import keras
import keras_nlp

Hinweise zum Mixed Precision-Training auf NVIDIA-GPUs

Beim Training auf NVIDIA-GPUs kann mit gemischter Präzision (keras.mixed_precision.set_global_policy('mixed_bfloat16')) das Training mit minimaler Auswirkung auf die Trainingsqualität beschleunigt werden. In den meisten Fällen empfiehlt es sich, die Verwendung von gemischter Genauigkeit zu aktivieren, da Sie so Arbeitsspeicher und Zeit sparen. Beachten Sie jedoch, dass die Arbeitsspeichernutzung bei kleinen Batchgrößen um das 1,5-Fache in die Höhe getrieben werden kann (Gewichtungen werden zweimal mit halber und voller Genauigkeit geladen).

Für Inferenz funktioniert die halbe Genauigkeit (keras.config.set_floatx("bfloat16")) und es wird Arbeitsspeicher gespart, während keine gemischte Genauigkeit anwendbar ist.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Zum Laden des Modells mit den Gewichten und Tensoren, die auf TPUs verteilt sind, müssen Sie zuerst eine neue DeviceMesh erstellen. DeviceMesh steht für eine Sammlung von Hardwaregeräten, die für verteilte Berechnungen konfiguriert sind, und wurde in Keras 3 als Teil der Unified Distribution API eingeführt.

Die Distribution API ermöglicht Daten- und Modellparallelität und ermöglicht eine effiziente Skalierung von Deep-Learning-Modellen auf mehreren Beschleunigern und Hosts. Es nutzt das zugrunde liegende Framework (z.B. JAX), um das Programm und die Tensoren gemäß den Sharding-Anweisungen mithilfe eines Verfahrens zu verteilen, das als Single Program, Multiple Data (SPMD) Erweiterung bezeichnet wird. Weitere Informationen finden Sie im neuen Keras 3 Distribution API-Leitfaden.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

LayoutMap aus der Distribution API gibt an, wie die Gewichtungen und Tensoren fragmentiert oder repliziert werden sollen. Dabei werden die Stringschlüssel verwendet, z. B. token_embedding/embeddings unten. Diese werden wie ein regulärer Ausdruck behandelt, um Tensor-Pfade abzugleichen. Übereinstimmende Tensoren werden mit Modelldimensionen (8 TPUs) fragmentiert. andere werden vollständig repliziert.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

Mit ModelParallel können Sie Modellgewichtungen oder Aktivierungstensoren für alle Abweichungen im DeviceMesh fragmentieren. In diesem Fall sind einige der Gemma 7B-Modellgewichtungen gemäß der oben definierten layout_map auf 8 TPU-Chips aufgeteilt. Laden Sie das Modell nun verteilt.

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

Prüfen Sie nun, ob das Modell korrekt partitioniert wurde. Nehmen wir decoder_block_1 als Beispiel.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

Inferenz vor der Feinabstimmung

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

Das Modell generiert eine Liste großartiger Komödien aus den 1990er-Jahren. Jetzt optimieren wir das Gemma-Modell, um den Ausgabestil zu ändern.

Abstimmung mit IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

Führen Sie die Feinabstimmung mithilfe der Low-Rank-Anpassung (LoRA) durch. LoRA ist eine Feinabstimmungstechnik, die die Anzahl der trainierbaren Parameter für nachgelagerte Aufgaben erheblich reduziert, indem die gesamten Gewichte des Modells eingefroren und eine kleinere Anzahl neuer trainierbarer Gewichte in das Modell eingefügt werden. Im Grunde parametrisiert LoRA die größeren Vollgewichtsmatrizen um zwei kleinere niedrigrangige Matrizen AxB, um zu trainieren, und diese Technik macht das Training viel schneller und speichereffizienter.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

Beachten Sie, dass durch die Aktivierung von LoRA die Anzahl der trainierbaren Parameter erheblich reduziert wird – von 7 Milliarden auf nur 11 Millionen.

Inferenz nach der Feinabstimmung

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

Nach der Feinabstimmung hat das Modell den Stil von Filmrezensionen gelernt und generiert nun Ausgaben in diesem Stil im Kontext von Komödien der 90er.

Nächste Schritte

In diesem Tutorial haben Sie gelernt, wie Sie mit dem KerasNLP JAX-Backend ein Gemma-Modell für das IMDb-Dataset auf verteilten leistungsstarken TPUs optimieren. Hier noch ein paar Vorschläge, was Sie sonst noch lernen können: