Verteilte Feinabstimmung mit Gemma unter Verwendung von Keras

Auf ai.google.dev ansehen In Kaggle ausführen In Vertex AI öffnen Quelle auf GitHub ansehen

Überblick

Gemma ist eine Familie von leichten, hochmodernen offenen Modellen, die auf der Forschung und Technologie basieren, die zum Erstellen von Google Gemini-Modellen verwendet werden. Gemma kann weiter auf ihre spezifischen Bedürfnisse abgestimmt werden. Large Language Models wie Gemma können jedoch sehr groß sein und einige von ihnen passen für die Feinabstimmung möglicherweise nicht auf einen Singbeschleuniger. In diesem Fall gibt es zwei allgemeine Ansätze zur Feinabstimmung:

  1. Parameter-Effiziente Feinabstimmung (PEFT), mit der die effektive Modellgröße durch Abstriche bei der Genauigkeit verringert werden soll. LoRA fällt in diese Kategorie und die Anleitung Gemma-Modelle in Keras mit LoRA optimieren zeigt, wie das Gemma 2B-Modell gemma_2b_en mit LoRA mithilfe von KerasNLP auf einer einzelnen GPU optimiert wird.
  2. Vollständige Parameterabstimmung mit Modellparallelität. Die Modellparallelität verteilt die Gewichtung eines einzelnen Modells auf mehrere Geräte und ermöglicht eine horizontale Skalierung. Weitere Informationen zum verteilten Training finden Sie in diesem Keras-Leitfaden.

In dieser Anleitung wird die Verwendung von Keras mit einem JAX-Back-End erläutert, um das Gemma 7B-Modell mit LoRA und dem verteilten Training auf Modellparallismus auf der Tensor Processing Unit (TPU) von Google zu optimieren. Beachten Sie, dass LoRA in dieser Anleitung deaktiviert werden kann, um eine langsamere, aber genauere Einstellung der vollen Parameter zu ermöglichen.

Beschleuniger verwenden

Technisch gesehen können Sie für diese Anleitung entweder eine TPU oder eine GPU verwenden.

Hinweise zu TPU-Umgebungen

Google hat drei Produkte, die TPUs anbieten:

  • Colab bietet TPU v2, das für diese Anleitung nicht ausreicht.
  • Kaggle bietet TPU v3 kostenlos an und funktioniert für diese Anleitung.
  • Cloud TPU bietet TPU v3 und neuere Generationen. Eine Möglichkeit zur Einrichtung:
    1. Neue TPU-VM erstellen
    2. Richten Sie die SSH-Portweiterleitung für den vorgesehenen Jupyter-Serverport ein.
    3. Installieren Sie Jupyter, starten Sie es auf der TPU-VM und stellen Sie dann über „Mit lokaler Laufzeit verbinden“ eine Verbindung zu Colab her

Hinweise zur Einrichtung mit mehreren GPUs

Der Schwerpunkt dieser Anleitung liegt zwar auf dem TPU-Anwendungsfall, Sie können sie aber problemlos an Ihre eigenen Anforderungen anpassen, wenn Sie einen Computer mit mehreren GPUs haben.

Wenn Sie lieber mit Colab arbeiten möchten, können Sie auch eine VM mit mehreren GPUs für Colab direkt über „Verbindung zu einer benutzerdefinierten GCE-VM herstellen“ im Colab Connect-Menü bereitstellen.

Wir werden uns hier auf die Verwendung der kostenlosen TPU von Kaggle konzentrieren.

Hinweis

Kaggle-Qualifikationen

Gemma-Modelle werden von Kaggle gehostet. Wenn Sie Gemma verwenden möchten, fordern Sie Zugriff auf Kaggle an:

  • Anmelden oder registrieren auf kaggle.com
  • Öffnen Sie die Gemma-Modellkarte und wählen Sie Zugriff anfordern aus.
  • Einverständniserklärung ausfüllen und Nutzungsbedingungen akzeptieren

Erstellen Sie anschließend ein API-Token, um die Kaggle-API zu verwenden:

  • Öffnen Sie die Kaggle-Einstellungen.
  • Wählen Sie Create New Token (Neues Token erstellen) aus.
  • Eine kaggle.json-Datei wird heruntergeladen. Es enthält Ihre Kaggle-Qualifikationen.

Führen Sie die folgende Zelle aus und geben Sie Ihre Kaggle-Anmeldedaten ein, wenn Sie dazu aufgefordert werden.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Alternativ können Sie KAGGLE_USERNAME und KAGGLE_KEY in Ihrer Umgebung festlegen, wenn kagglehub.login() bei Ihnen nicht funktioniert.

Installation

Keras und KerasNLP mit dem Gemma-Modell installieren

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Keras-JAX-Backend einrichten

Importieren Sie JAX und führen Sie eine Plausibilitätsprüfung auf der TPU aus. Kaggle bietet TPUv3-8-Geräte mit 8 TPU-Kernen mit jeweils 16 GB Arbeitsspeicher.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Modell laden

import keras
import keras_nlp

Hinweise zum Mixed Precision Training auf NVIDIA-GPUs

Beim Training mit NVIDIA-GPUs kann gemischte Precision (keras.mixed_precision.set_global_policy('mixed_bfloat16')) verwendet werden, um das Training mit minimaler Auswirkung auf die Trainingsqualität zu beschleunigen. In den meisten Fällen wird empfohlen, die gemischte Genauigkeit zu aktivieren, da dadurch sowohl Arbeitsspeicher als auch Zeit gespart wird. Beachten Sie jedoch, dass die Arbeitsspeichernutzung bei kleinen Batchgrößen um das 1,5-Fache erhöht werden kann (Gewichtungen werden zweimal mit halber und voller Genauigkeit geladen).

Bei Inferenz funktioniert eine halbe Genauigkeit (keras.config.set_floatx("bfloat16")) und spart Arbeitsspeicher, während gemischte Genauigkeit nicht anwendbar ist.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Erstellen Sie zuerst eine neue DeviceMesh, um das Modell mit den auf TPUs verteilten Gewichten und Tensoren zu laden. DeviceMesh steht für eine Reihe von Hardwaregeräten, die für verteilte Berechnungen konfiguriert sind und in Keras 3 als Teil der Unified Distribution API eingeführt wurden.

Die Distribution API ermöglicht Daten- und Modellparallelität und ermöglicht eine effiziente Skalierung von Deep-Learning-Modellen auf mehreren Beschleunigern und Hosts. Es nutzt das zugrunde liegende Framework (z. B. JAX), um das Programm und die Tensoren gemäß den Fragmentierungsanweisungen zu verteilen. Dazu wird ein Verfahren namens Single Program, Multiple Data (SPMD) verwendet. Weitere Informationen finden Sie im neuen Leitfaden zur Keras 3 Distribution API.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

LayoutMap aus der Distribution API gibt an, wie die Gewichtungen und Tensoren fragmentiert oder repliziert werden sollen. Dazu werden die Stringschlüssel verwendet, z. B. token_embedding/embeddings unten, die wie Regex behandelt werden, um Tensor-Pfade abzugleichen. Übereinstimmende Tensoren werden mit Modelldimensionen fragmentiert (8 TPUs); andere werden vollständig repliziert.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

Mit ModelParallel können Sie Modellgewichtungen oder Aktivierungstensoren über alle Anwendungen auf dem DeviceMesh fragmentieren. In diesem Fall sind einige der Gewichtungen des Gemma 7B-Modells gemäß der oben definierten layout_map auf 8 TPU-Chips aufgeteilt. Laden Sie das Modell jetzt verteilt.

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

Prüfen Sie nun, ob das Modell korrekt partitioniert wurde. Nehmen wir decoder_block_1 als Beispiel.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

Inferenz vor Feinabstimmung

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

Das Modell generiert eine Liste mit tollen Komödien aus den 90er-Jahren. Jetzt optimieren wir das Gemma-Modell, um den Ausgabestil zu ändern.

Mit IMDB optimieren

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

Führen Sie eine Feinabstimmung mithilfe der Low Rank Adaptation (LoRA) durch. LoRA ist eine Feinabstimmungstechnik, bei der die Anzahl der trainierbaren Parameter für nachgelagerte Aufgaben stark reduziert wird. Dazu werden die vollständigen Gewichtungen des Modells eingefroren und eine kleinere Anzahl neuer trainierbarer Gewichtungen in das Modell eingefügt. Im Grunde parametriert LoRA die größeren Vollgewichtsmatrizen für das Training durch 2 kleinere niedrig eingestufte Matrizen AxB, und diese Technik macht das Training viel schneller und speichereffizienter.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

Beachten Sie, dass durch die Aktivierung von LoRA die Anzahl der trainierbaren Parameter deutlich reduziert wird – von 7 Milliarden auf nur 11 Millionen.

Inferenz nach Feinabstimmung

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

Nach der Feinabstimmung hat das Modell den Stil von Filmrezensionen gelernt und generiert nun Ausgaben in diesem Stil für Komödien der 90er-Jahre.

Nächste Schritte

In diesem Tutorial haben Sie gelernt, wie Sie das KerasNLP JAX-Back-End verwenden können, um ein Gemma-Modell für das IMDb-Dataset verteilt auf den leistungsstarken TPUs zu optimieren. Hier sind ein paar Vorschläge, was Sie sonst noch lernen können: