Dostrajanie rozproszone za pomocą Gemmy przy użyciu Keras

Zobacz na ai.google.dev Uruchom w Kaggle Otwórz w Vertex AI Wyświetl źródło na GitHubie

Przegląd

Gemma to rodzina lekkich, najnowocześniejszych otwartych modeli stworzonych na podstawie badań i technologii używanych do tworzenia modeli Google Gemini. Gemma można dodatkowo dostosować do konkretnych potrzeb. Jednak duże modele językowe, takie jak Gemma, mogą być bardzo duże i niektóre z nich nie mieszczą się w akceleratorze śpiewnym do dostrajania. W tym przypadku można zastosować 2 ogólne podejścia:

  1. Efektywne dostrajanie parametrów (PEFT), które ma na celu zmniejszenie efektywnego rozmiaru modelu przez poświęcenie pewnej dokładności. Do tej kategorii należy LoRA, a samouczek dostrajania modeli Gemma w Keras przy użyciu LoRA pokazuje, jak dostroić model gemma_2b_en Gemma 2B za pomocą LoRA za pomocą KerasNLP z użyciem pojedynczego GPU.
  2. Pełne dostrajanie parametrów z użyciem równoległości modelu. Równoległość modelu rozkłada wagi pojedynczego modelu na wiele urządzeń i umożliwia skalowanie w poziomie. Więcej informacji o trenowaniu rozproszonym znajdziesz w tym przewodniku po Keras.

Ten samouczek przedstawia, jak używać Keras z backendem JAX do dostrajania modelu Gemma 7B z użyciem LoRA i rozproszonego trenowania paralizmów w Google Tensor Processing Unit (TPU). Uwaga: w tym samouczku można wyłączyć LoRA, aby przyspieszyć dostrajanie pełnych parametrów.

Zastosowanie akceleratorów

Technicznie rzecz biorąc, na potrzeby tego samouczka możesz użyć TPU lub GPU.

Uwagi dotyczące środowisk TPU

Google oferuje 3 usługi zapewniające jednostki TPU:

  • Colab zapewnia dostęp do TPU w wersji 2, która nie jest wystarczająca w tym samouczku.
  • Kaggle oferuje TPU v3 bezpłatnie i działa w tym samouczku.
  • Cloud TPU obejmuje TPU v3 i nowsze generacje. Możesz to zrobić na jeden z tych sposobów:
    1. Tworzenie nowej maszyny wirtualnej TPU
    2. Skonfiguruj przekierowywanie portów SSH dla odpowiedniego portu serwera Jupyter
    3. Zainstaluj serwer Jupyter i uruchom go w maszynie wirtualnej TPU, a następnie połącz się z Colab za pomocą opcji „Połącz z lokalnym środowiskiem wykonawczym”

Uwagi dotyczące konfiguracji z wieloma procesorami graficznymi

Chociaż ten samouczek dotyczy zastosowania TPU, z łatwością dostosujesz go do swoich potrzeb, jeśli masz maszynę z wieloma procesorami graficznymi.

Jeśli wolisz korzystać z Colab, możesz też bezpośrednio udostępnić maszynę wirtualną z wieloma procesorami graficznymi na potrzeby Colab. Aby to zrobić, w menu Colab Connect kliknij „Połącz z niestandardową maszyną wirtualną GCE”.

W tym miejscu skupimy się na korzystaniu z bezpłatnej TPU od Kaggle.

Zanim zaczniesz

Dane logowania do Kaggle

Modele Gemma są hostowane przez Kaggle. Aby używać Gemma, poproś o dostęp w Kaggle:

  • Zaloguj się lub zarejestruj na kaggle.com.
  • Otwórz kartę modelu Gemma i wybierz „Poproś o dostęp”.
  • Wypełnić formularz zgody i zaakceptować Warunki korzystania z usługi

Następnie, aby korzystać z interfejsu Kaggle API, utwórz token API:

  • Otwórz ustawienia Kaggle.
  • Wybierz Create New Token (Utwórz nowy token).
  • Pobrano plik kaggle.json. Zawiera Twoje dane logowania do Kaggle

Uruchom poniższą komórkę, a gdy pojawi się prośba, wpisz swoje dane logowania do Kaggle.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Jeśli metoda kagglehub.login() nie działa, możesz też ustawić w środowisku KAGGLE_USERNAME i KAGGLE_KEY.

Instalacja

Instalowanie Keras i KerasNLP z użyciem modelu Gemma.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Skonfiguruj backend JAX Keras

Zaimportuj JAX i sprawdź poprawność TPU. Kaggle oferuje urządzenia z TPUv3-8, które mają 8 rdzeni TPU i 16 GB pamięci.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Wczytaj model

import keras
import keras_nlp

Uwagi na temat trenowania z mieszaną precyzją w przypadku procesorów graficznych NVIDIA

Podczas trenowania na procesorach graficznych NVIDIA można używać mieszanej precyzji (keras.mixed_precision.set_global_policy('mixed_bfloat16')), aby przyspieszyć trenowanie przy minimalnym wpływie na jego jakość. W większości przypadków zalecamy włączenie mieszanej precyzji, ponieważ pozwala to zaoszczędzić zarówno pamięć, jak i czas. Pamiętaj jednak, że przy małych wsadach może to zwiększyć wykorzystanie pamięci 1, 5 raza (wagi będą wczytywane dwukrotnie, co oznacza połowę dokładności i pełnej precyzji).

Aby wnioskować, połowa precyzji (keras.config.set_floatx("bfloat16")) będzie działać i oszczędzać pamięć, podczas gdy mieszana precyzja nie jest stosowana.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Aby wczytać model z wagami i tensorami rozłożonymi między TPU, najpierw utwórz nowy obiekt DeviceMesh. DeviceMesh to zbiór urządzeń skonfigurowanych pod kątem obliczeń rozproszonych i wprowadzony w Keras 3 w ramach ujednoliconego interfejsu API dystrybucji.

Interfejs API dystrybucji umożliwia równoległość danych i modeli, co pozwala na skuteczne skalowanie modeli deep learning na wielu akceleratorach i hostach. Wykorzystuje podstawową strukturę (np. JAX) do dystrybucji programu i tensorów zgodnie z dyrektywami dotyczącymi fragmentacji w ramach procedury nazywanej jednym programem, rozwijaniem z użyciem wielu danych (SPMD). Więcej informacji znajdziesz w nowym przewodniku po interfejsie API dystrybucji Keras 3.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

Parametr LayoutMap z interfejsu API dystrybucji określa sposób, w jaki wagi i tensory powinny być poddawane fragmentacji lub replikacji, korzystając z kluczy ciągów, np. token_embedding/embeddings poniżej, które są traktowane jak wyrażenie regularne pasujące do ścieżek tensorów. Dopasowane tensory są fragmentowane w ramach wymiarów modelu (8 TPU). Inne zostaną w pełni zreplikowane.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel umożliwia fragmentowanie wag modelu lub tensorów aktywacji we wszystkich środowiskach w DeviceMesh. W tym przypadku niektóre wagi modelu Gemma 7B są podzielone na 8 układów TPU zgodnie z definicją layout_map. Teraz załaduj model w sposób rozproszony.

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

Teraz sprawdź, czy model został prawidłowo partycjonowany. Weźmy za przykład decoder_block_1.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

Wnioskowanie przed dostrojeniem

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

Model generuje listę świetnych komedii z lat 90. do obejrzenia. Teraz dostrajamy model Gemma, aby zmienić styl wyjściowy.

Dostrój za pomocą IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

Przeprowadź dostrajanie, korzystając z metody Low Rank Adaptation (LoRA). LoRA to technika dostrajania, która znacznie zmniejsza liczbę możliwych do wytrenowania parametrów w zadaniach na dalszych etapach przez zablokowanie pełnych wag modelu i wstawienie do niego mniejszej liczby nowych wag możliwych do wytrenowania. LoRA dostosowuje do trenowania większe macierze pełnej wagi, wykorzystując do trenowania 2 mniejsze macierze AxB o niskiej pozycji w rankingu. Ta technika sprawia, że trenowanie jest szybsze i bardziej efektywne.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

Pamiętaj, że włączenie LoRA znacznie zmniejsza liczbę parametrów do trenowania – z 7 miliardów do zaledwie 11 milionów.

Wnioskowanie po dostrajaniu

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

Po dostrojeniu model nauczył się stylu recenzji filmów i teraz generuje w tym stylu wyniki komediowe z lat 90.

Co dalej

W tym samouczku pokazaliśmy, jak korzystać z backendu JAX KerasNLP do dostrajania modelu Gemma w zbiorze danych IMDb w rozproszony sposób w mocnych jednostkach TPU. Oto kilka sugestii, o których jeszcze warto się dowiedzieć: