Dostrajanie rozproszone za pomocą Gemmy przy użyciu Keras

Wyświetl na stronie ai.google.dev Uruchom w Google Colab Uruchom w Kaggle Otwórz w Vertex AI Wyświetl źródło w GitHubie

Omówienie

Gemma to rodzina lekkich, otwartych modeli opartych na najnowszych badaniach i technologiach wykorzystanych do tworzenia modeli Google Gemini. Gemma może być dalej dostosowywana do konkretnych potrzeb. Duże modele językowe, takie jak Gemma, mogą być bardzo duże, a niektóre z nich mogą nie zmieścić się na akceleratorze singla do dopracowywania. W takim przypadku możesz je dostosować na 2 sposoby:

  1. dostrajanie konkretnych parametrów (PEFT), które ma na celu zmniejszenie rozmiaru modelu poprzez poświęcenie części wierności. Do tej kategorii należy LoRA. Samouczek Dostosowywanie modeli Gemma w Keras za pomocą LoRA pokazuje, jak dostosować model Gemma 2B gemma_2b_en za pomocą LoRA przy użyciu KerasNLP na jednym GPU.
  2. Pełne dostrajanie parametrów z paralelizmem modelu. Parallizm modeli polega na rozprowadzaniu wag pojedynczego modelu na wiele urządzeń i umożliwia skalowanie poziome. Więcej informacji o trenowaniu rozproszonym znajdziesz w tym przewodniku Keeras.

W tym samouczku dowiesz się, jak użyć Keras z backendem JAX do dopracowania modelu Gemma 7B z LoRA i rozproszonego trenowania z modelem równoległym na Tensor Processing Unit (TPU) Google. Aby uzyskać wolniejsze, ale dokładniejsze dostrajanie pełnych parametrów, funkcję LoRA można wyłączyć w tym samouczku.

Korzystanie z akceleratorów

W tym samouczku możesz technicznie użyć TPU lub GPU.

Uwagi dotyczące środowisk TPU

Google oferuje 3 usługi, które zapewniają TPU:

  • Colab udostępnia bezpłatnie TPU v2, co wystarczy do tego samouczka.
  • Kaggle oferuje bezpłatnie TPU v3, które można wykorzystać w tym samouczku.
  • Cloud TPU oferuje TPU w wersji 3 i nowszych. Oto jeden ze sposobów konfiguracji:
    1. Utwórz nową maszynę wirtualną TPU.
    2. Skonfiguruj przekierowanie portu SSH dla wybranego portu serwera Jupyter.
    3. Zainstaluj Jupyter i uruchom go na maszynie wirtualnej TPU, a następnie połącz się z Colab za pomocą opcji „Połącz z lokalnym środowiskiem wykonawczym”.

Uwagi dotyczące konfiguracji z wieloma procesorami graficznymi

Chociaż ten samouczek skupia się na przypadku użycia TPU, możesz go łatwo dostosować do własnych potrzeb, jeśli masz maszynę z wieloma procesorami graficznymi.

Jeśli wolisz korzystać z Colab, możesz też udostępnić maszynę wirtualną z wieloma procesorami graficznymi na potrzeby Colab bezpośrednio za pomocą opcji „Połącz się z niestandardową maszyną wirtualną GCE” w menu Colab Connect.

W tym artykule skupimy się na bezpłatnym TPU z Kaggle.

Zanim zaczniesz

Dane logowania do Kaggle

Modele Gemma są hostowane w Kaggle. Aby korzystać z Gemma, poproś o dostęp na Kaggle:

  • Zaloguj się lub zarejestruj na stronie kaggle.com.
  • Otwórz kartę modelu Gemma i wybierz „Request Access” (Poproś o dostęp).
  • Wypełnij formularz zgody i zaakceptuj warunki korzystania z usługi.

Aby korzystać z interfejsu Kaggle API, utwórz token API:

  • Otwórz ustawienia Kaggle.
  • Kliknij „Utwórz nowy token”.
  • Pobrano plik kaggle.json. Zawiera Twoje dane logowania do Kaggle

Uruchom następującą komórkę i wpisz dane logowania Kaggle, gdy pojawi się prośba.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Jeśli funkcja kagglehub.login() nie działa, możesz spróbować ustawić zmienne KAGGLE_USERNAME i KAGGLE_KEY w swoim środowisku.

Instalacja

Zainstaluj Keras i KerasNLP z użyciem modelu Gemma.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Konfigurowanie backendu Keras JAX

Zaimportuj JAX i przeprowadź kontrolę poprawności na TPU. Kaggle oferuje urządzenia TPUv3-8 z 8 rdzeniami TPU i 16 GB pamięci na każde z nich.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Wczytaj model

import keras
import keras_nlp

Informacje o trenowaniu z użyciem różnych poziomów precyzji na procesorach graficznych NVIDIA

Podczas trenowania na GPU firmy NVIDIA można użyć mieszanej precyzji (keras.mixed_precision.set_global_policy('mixed_bfloat16')), aby przyspieszyć trenowanie przy minimalnym wpływie na jakość. W większości przypadków zalecamy włączenie mieszanej precyzji, ponieważ oszczędza to zarówno pamięć, jak i czas. Pamiętaj jednak, że przy małych rozmiarach partii może to zwiększyć wykorzystanie pamięci 1, 5 raza (wagi zostaną załadowane dwukrotnie, z półpełną i pełną dokładnością).

W przypadku wnioskowania użyj precyzji połowicznej (keras.config.set_floatx("bfloat16")), która zaoszczędzi pamięć, a nie precyzji mieszanej.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Aby załadować model z wagami i tensorami rozłożonymi na TPU, najpierw utwórz nowy DeviceMesh. DeviceMesh reprezentuje zbiór urządzeń sprzętowych skonfigurowanych do obliczeń rozproszonych i został wprowadzony w Keras 3 jako część interfejsu Unified Distribution API.

Interfejs Distribution API umożliwia równoległość danych i modeli, co pozwala na efektywne skalowanie modeli uczenia głębokiego na wielu akceleratorach i hostach. Korzysta on z podstawowej platformy (np. JAX), aby rozprowadzać program i tensory zgodnie z instrukcjami podziału za pomocą procedury zwanej poszerzaniem w ramach pojedynczego programu z wieloma danymi (SPMD). Więcej informacji znajdziesz w nowym przewodniku po interfejsie API dystrybucji Keras 3.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

Pole LayoutMap z interfejsu Distribution API określa, w jaki sposób wagi i tensory powinny być dzielone lub replikowane za pomocą kluczy ciągu znaków, np. token_embedding/embeddings poniżej, które są traktowane jak wyrażenie regularne w celu dopasowania ścieżek tensora. Dopasowane tensory są dzielone według wymiarów modelu (8 procesorów TPU); inne są w pełni powielane.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel umożliwia dzielenie wag modelu lub tensorów aktywacji na wszystkie urządzenia w DeviceMesh. W tym przypadku niektóre wartości wagowe modelu Gemma 7B są dzielone na 8 chipów TPU zgodnie z definicją layout_map podaną powyżej. Teraz wczytaj model w sposób rozproszony.

model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

Teraz sprawdź, czy model został podzielony prawidłowo. Weźmy na przykład decoder_block_1.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

Wnioskowanie przed dostrojem

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

Model generuje listę świetnych komedii z lat 90., które warto obejrzeć. Teraz dostosowujemy model Gemma, aby zmienić styl wyjściowy.

Dostrajanie z IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

Dokonać dostosowania za pomocą adaptacji niskiego rzędu (LoRA). LoRA to technika dostrajania, która znacznie zmniejsza liczbę parametrów możliwych do wytrenowania na kolejnych zadaniach przez zamrożenie pełnych ciężarów modelu i wstawienie do modelu mniejszej liczby nowych wag dostępnych do trenowania. W podstawie LoRA polega na parametryzacji większych pełnych macierzy wag za pomocą 2 mniejszych macierzy niskiego rzędu AxB, które służą do trenowania. Dzięki tej metodzie trenowanie jest znacznie szybsze i bardziej efektywne pod względem pamięci.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

Pamiętaj, że włączenie LoRA znacznie zmniejsza liczbę parametrów, które można trenować, z 7 mld do zaledwie 11 mln.

Wnioskowanie po dostrajaniu

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

Po dopracowaniu model nauczył się stylu recenzji filmowych i generuje teraz wyniki w tym stylu w kontekście komedii z lat 90. XX w.

Co dalej?

Z tego samouczka dowiesz się, jak użyć backendu KerasNLP JAX do dostrojenia modelu Gemma na zbiorze danych IMDb w sposób rozproszony na wydajnych procesorach TPU. Oto kilka sugestii dotyczących tego, czego jeszcze możesz się nauczyć: