Auf ai.google.dev ansehen | In Google Colab ausführen | In Kaggle ausführen | In Vertex AI öffnen | Quellcode auf GitHub ansehen |
Übersicht
Gemma ist eine Familie leichter, hochmoderner offener Modelle, die auf der Forschung und Technologie basieren, die auch für die Erstellung der Gemini-Modelle von Google verwendet werden. Gemma kann weiter an spezifische Anforderungen angepasst werden. Large Language Models wie Gemma können jedoch sehr groß sein und einige von ihnen passen möglicherweise nicht auf einen einzelnen Beschleuniger zur Feinabstimmung. In diesem Fall gibt es zwei allgemeine Ansätze zur Feinabstimmung:
- Parametereffiziente Feinabstimmung (PEFT), bei der die effektive Modellgröße durch Einbußen bei der Genauigkeit verringert wird. LoRA fällt in diese Kategorie. In der Anleitung Gemma-Modelle in Keras mit LoRA optimieren wird gezeigt, wie Sie das Gemma 2B-Modell
gemma_2b_en
mit LoRA und KerasNLP auf einer einzelnen GPU optimieren. - Vollständige Parameterfeinabstimmung mit Modellparallelität. Bei der Modellparallelität werden die Gewichte eines einzelnen Modells auf mehrere Geräte verteilt und eine horizontale Skalierung wird ermöglicht. Weitere Informationen zum verteilten Training finden Sie in diesem Keras-Leitfaden.
In dieser Anleitung erfahren Sie, wie Sie Keras mit einem JAX-Backend verwenden, um das Gemma 7B-Modell mit LoRA und verteiltem Training mit Modellparallelismus auf der Tensor Processing Unit (TPU) von Google zu optimieren. Beachten Sie, dass LoRA in dieser Anleitung deaktiviert werden kann, um eine langsamere, aber präzisere Feinabstimmung der vollständigen Parameter zu erreichen.
Beschleuniger verwenden
Technisch gesehen können Sie für diese Anleitung entweder eine TPU oder eine GPU verwenden.
Hinweise zu TPU-Umgebungen
Google bietet drei Produkte mit TPUs an:
- Colab bietet TPU v2 kostenlos an, was für diese Anleitung ausreicht.
- Kaggle bietet TPU v3 kostenlos an und diese können auch für diese Anleitung verwendet werden.
- Cloud TPU bietet TPU v3 und neuere Generationen. So richten Sie die Funktion ein:
- Neue TPU-VM erstellen
- SSH-Portweiterleitung für den gewünschten Jupyter-Serverport einrichten
- Installieren Sie Jupyter und starten Sie es auf der TPU-VM. Stellen Sie dann über „Mit lokaler Laufzeit verbinden“ eine Verbindung zu Colab her.
Hinweise zur Einrichtung mehrerer GPUs
In dieser Anleitung liegt der Schwerpunkt auf dem TPU-Anwendungsfall. Sie können sie jedoch ganz einfach an Ihre eigenen Anforderungen anpassen, wenn Sie einen Computer mit mehreren GPUs haben.
Wenn Sie lieber mit Colab arbeiten, können Sie eine VM mit mehreren GPUs auch direkt über „Mit einer benutzerdefinierten GCE-VM verbinden“ im Menü „Verbinden“ in Colab bereitstellen.
Wir konzentrieren uns hier auf die Verwendung der kostenlosen TPU von Kaggle.
Hinweis
Kaggle-Anmeldedaten
Gemma-Modelle werden von Kaggle gehostet. Wenn Sie Gemma verwenden möchten, beantragen Sie den Zugriff auf Kaggle:
- Melden Sie sich unter kaggle.com an oder registrieren Sie sich.
- Öffnen Sie die Gemma-Modellkarte und wählen Sie „Zugriff anfordern“ aus.
- Füllen Sie das Einwilligungsformular aus und akzeptieren Sie die Nutzungsbedingungen.
Erstellen Sie dann ein API-Token, um die Kaggle API zu verwenden:
- Öffnen Sie die Kaggle-Einstellungen.
- Wählen Sie Create New Token (Neues Token erstellen) aus.
- Eine
kaggle.json
-Datei wird heruntergeladen. Es enthält Ihre Kaggle-Anmeldedaten.
Führen Sie die folgende Zelle aus und geben Sie Ihre Kaggle-Anmeldedaten ein, wenn Sie dazu aufgefordert werden.
# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub
kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…
Alternativ können Sie KAGGLE_USERNAME und KAGGLE_KEY in Ihrer Umgebung festlegen, wenn kagglehub.login() nicht funktioniert.
Installation
Installieren Sie Keras und KerasNLP mit dem Gemma-Modell.
pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras
Keras JAX-Backend einrichten
Importieren Sie JAX und führen Sie eine Plausibilitätsprüfung auf der TPU durch. Kaggle bietet TPUv3-8-Geräte mit 8 TPU-Kernen mit jeweils 16 GB Speicher an.
import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os
# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
Modell laden
import keras
import keras_nlp
Hinweise zum Training mit gemischter Genauigkeit auf NVIDIA-GPUs
Beim Training auf NVIDIA-GPUs kann mit gemischter Präzision (keras.mixed_precision.set_global_policy('mixed_bfloat16')
) das Training mit minimaler Auswirkung auf die Trainingsqualität beschleunigt werden. In den meisten Fällen empfiehlt es sich, die Verwendung von gemischter Genauigkeit zu aktivieren, da Sie so Arbeitsspeicher und Zeit sparen. Bei kleinen Batchgrößen kann die Speichernutzung jedoch um das 1, 5-Fache ansteigen, da die Gewichte zweimal geladen werden, einmal mit halber und einmal mit voller Genauigkeit.
Für die Inferenz funktioniert die Halbpräzision (keras.config.set_floatx("bfloat16")
) und spart Arbeitsspeicher, während die gemischte Präzision nicht geeignet ist.
# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')
Wenn Sie das Modell mit den Gewichten und Tensoren laden möchten, die auf TPUs verteilt sind, erstellen Sie zuerst eine neue DeviceMesh
. DeviceMesh
steht für eine Sammlung von Hardwaregeräten, die für verteilte Berechnungen konfiguriert sind. Sie wurde in Keras 3 als Teil der Unified Distribution API eingeführt.
Die Distribution API ermöglicht Daten- und Modellparallelität, was eine effiziente Skalierung von Deep-Learning-Modellen auf mehreren Beschleunigern und Hosts ermöglicht. Es nutzt das zugrunde liegende Framework (z.B. JAX), um das Programm und die Tensoren gemäß den Sharding-Anweisungen mithilfe eines Verfahrens zu verteilen, das als Single Program, Multiple Data (SPMD) Erweiterung bezeichnet wird. Weitere Informationen finden Sie im neuen API-Leitfaden für die Keras 3-Distribution.
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
(1, 8),
["batch", "model"],
devices=keras.distribution.list_devices())
LayoutMap
aus der Distributions-API gibt an, wie die Gewichte und Tensoren mithilfe der Stringschlüssel, z. B. token_embedding/embeddings
unten, gesplittet oder repliziert werden sollen. Diese werden wie reguläre Ausdrücke behandelt, um Tensorpfade abzugleichen. Übereinstimmende Tensoren werden mit Modelldimensionen (8 TPUs) geSharded; andere werden vollständig repliziert.
model_dim = "model"
layout_map = keras.distribution.LayoutMap(device_mesh)
# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
model_dim, None, None)
layout_map["decoder_block.*attention_output.*kernel"] = (
model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)
Mit ModelParallel
können Sie Modellgewichte oder Aktivierungstensoren auf alle Geräte auf der DeviceMesh
aufteilen. In diesem Fall werden einige der Gemma 7B-Modellgewichte gemäß der oben definierten layout_map
auf 8 TPU-Chips verteilt. Laden Sie das Modell nun verteilt.
model_parallel = keras.distribution.ModelParallel(
layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
Prüfen Sie nun, ob das Modell richtig partitioniert wurde. Nehmen wir decoder_block_1
als Beispiel.
decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'> decoder_block_1/pre_attention_norm/scale (3072,) PartitionSpec(None,) decoder_block_1/attention/query/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/key/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/value/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/attention_output/kernel (16, 256, 3072) PartitionSpec(None, None, 'model') decoder_block_1/pre_ffw_norm/scale (3072,) PartitionSpec(None,) decoder_block_1/ffw_gating/kernel (3072, 24576) PartitionSpec('model', None) decoder_block_1/ffw_gating_2/kernel (3072, 24576) PartitionSpec('model', None) decoder_block_1/ffw_linear/kernel (24576, 3072) PartitionSpec(None, 'model')
Inferenz vor der Feinabstimmung
gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'
Das Modell generiert eine Liste mit empfehlenswerten Komödien aus den 90er-Jahren. Jetzt optimieren wir das Gemma-Modell, um den Ausgabestil zu ändern.
Mit IMDb optimieren
import tensorflow_datasets as tfds
imdb_train = tfds.load(
"imdb_reviews",
split="train",
as_supervised=True,
batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)
imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0... Dl Completed...: 0 url [00:00, ? url/s] Dl Size...: 0 MiB [00:00, ? MiB/s] Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s] Generating train examples...: 0%| | 0/25000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord… Generating test examples...: 0%| | 0/25000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*… Generating unsupervised examples...: 0%| | 0/50000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t… Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data. b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)
Führen Sie eine Feinabstimmung mit Low Rank Adaptation (LoRA) durch. LoRA ist eine Methode zur Feinabstimmung, mit der die Anzahl der trainierbaren Parameter für nachfolgende Aufgaben erheblich reduziert wird. Dazu werden die gesamten Gewichte des Modells eingefroren und eine kleinere Anzahl neuer trainierbarer Gewichte in das Modell eingefügt. Im Grunde reparameterisiert LoRA die größeren vollständigen Gewichtungsmatrizen durch zwei kleinere Low-Rank-Matrizen AxB, um sie zu trainieren. Diese Technik macht das Training viel schneller und speichereffizienter.
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.
# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" 2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329 <keras.src.callbacks.history.History at 0x7e9cac7f41c0>
Hinweis: Wenn Sie LoRA aktivieren, wird die Anzahl der trainierbaren Parameter von 7 Milliarden auf nur 11 Millionen reduziert.
Inferenz nach der Feinabstimmung
gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."
Nach der Feinabstimmung hat das Modell den Stil von Filmrezensionen gelernt und generiert jetzt im Kontext von Komödienfilmen aus den 90er-Jahren Rezensionen in diesem Stil.
Nächste Schritte
In dieser Anleitung haben Sie gelernt, wie Sie mit dem KerasNLP JAX-Backend ein Gemma-Modell für das IMDb-Dataset auf verteilten leistungsstarken TPUs optimieren. Hier einige Vorschläge, was Sie sonst noch lernen können: