Dostrajanie rozproszone za pomocą Gemmy przy użyciu Keras

Wyświetl na ai.google.dev Uruchom w Google Colab Uruchom w Kaggle Otwórz w Vertex AI Wyświetl źródło w GitHubie

Omówienie

Gemma to rodzina lekkich, nowoczesnych modeli otwartych opartych na badaniach i technologii używanych do tworzenia modeli Gemini od Google. Gemmę można dodatkowo dostosować do konkretnych potrzeb. Duże modele językowe, takie jak Gemma, mogą być jednak bardzo duże, a niektóre z nich mogą nie mieścić się na akceleratorze śpiewania w celu dostrajania. W takim przypadku istnieją 2 ogólne sposoby ich dostrajania:

  1. Efektywne dostrajanie parametrów (PEFT), które ma zmniejszyć efektywny rozmiar modelu przy jednoczesnym obniżeniu dokładności. Do tej kategorii należy LoRA. Samouczek dostrajania modeli Gemma w Kera przy użyciu LoRA pokazuje, jak dostrajać model Gemma 2B gemma_2b_en z użyciem LoRA przy użyciu KerasNLP z użyciem pojedynczego GPU.
  2. Pełne dostrajanie parametrów dzięki równoległości modelu. Równoległość modelu rozdziela wagi pojedynczego modelu na wiele urządzeń i umożliwia skalowanie poziome. Więcej informacji o trenowaniu rozproszonym znajdziesz w tym przewodniku Keeras.

Ten samouczek przedstawia, jak używać Keras z backendem JAX do dostrajania modelu Gemma 7B z architekturą LoRA i rozproszonym trenowaniem rozproszonym za pomocą procesora Tensor (TPU) Google. Aby uzyskać wolniejsze, ale dokładniejsze dostrajanie pełnych parametrów, funkcję LoRA można wyłączyć w tym samouczku.

Korzystanie z akceleratorów

Technicznie rzecz biorąc, w tym samouczku możesz użyć TPU lub GPU.

Uwagi dotyczące środowisk TPU

Google oferuje 3 usługi, które dostarczają TPU:

  • Colab udostępnia bezpłatnie TPU v2 (co jest wystarczające w tym samouczku).
  • Kaggle udostępnia bezpłatnie układ TPU v3. Można z nich korzystać również w tym samouczku.
  • Cloud TPU oferuje TPU v3 i nowsze generacje. Możesz to zrobić na przykład:
    1. Utwórz nową maszynę wirtualną TPU.
    2. Skonfiguruj przekierowanie portów SSH dla odpowiedniego portu serwera Jupyter
    3. Zainstaluj serwer Jupyter i uruchom go w maszynie wirtualnej TPU, a następnie połącz się z Colab za pomocą opcji „Połącz z lokalnym środowiskiem wykonawczym”

Uwagi dotyczące konfiguracji z wieloma procesorami graficznymi

Chociaż ten samouczek skupia się na przypadku użycia TPU, możesz go łatwo dostosować do własnych potrzeb, jeśli masz maszynę z wieloma procesorami graficznymi.

Jeśli wolisz korzystać z Colab, możesz też udostępnić maszynę wirtualną z wieloma GPU na potrzeby Colab bezpośrednio za pomocą opcji „Połącz z niestandardową maszyną wirtualną GCE”. w menu Colab Connect.

Skupimy się na wykorzystaniu bezpłatnego TPU stworzonego przez Kaggle.

Zanim zaczniesz

Dane logowania Kaggle

Modele Gemma są hostowane przez Kaggle. Aby korzystać z Gemma, poproś o dostęp w Kaggle:

  • Zaloguj się lub zarejestruj na stronie kaggle.com.
  • Otwórz kartę modelu Gemma i wybierz „Poproś o dostęp”.
  • Wypełnij formularz zgody i zaakceptuj Warunki korzystania z usługi

Następnie, aby używać interfejsu Kaggle API, utwórz token API:

  • Otwórz ustawienia Kaggle.
  • Wybierz „Create New Token” (Utwórz nowy token).
  • Pobrano plik kaggle.json. Zawiera Twoje dane logowania Kaggle

Uruchom następującą komórkę i wpisz dane logowania Kaggle, gdy pojawi się prośba.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Jeśli metoda kagglehub.login() nie działa, możesz też ustawić w środowisku wartości KAGGLE_USERNAME i KAGGLE_KEY.

Instalacja

Zainstaluj Keras i KerasNLP z użyciem modelu Gemma.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Skonfiguruj backend Keras JAX

Zaimportuj JAX i przeprowadź kontrolę poprawności TPU. Kaggle oferuje urządzenia TPUv3-8 z 8 rdzeniami TPU i 16 GB pamięci każde z nich.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Wczytaj model

import keras
import keras_nlp

Uwagi dotyczące mieszanej precyzji trenowania w procesorach graficznych NVIDIA

Podczas trenowania w procesorach graficznych NVIDIA można użyć mieszanej precyzji (keras.mixed_precision.set_global_policy('mixed_bfloat16')), aby przyspieszyć trenowanie z minimalnym wpływem na jakość trenowania. W większości przypadków zalecamy włączenie mieszanej precyzji, ponieważ oszczędza to zarówno pamięć, jak i czas. Pamiętaj jednak, że przy niewielkich wsarach może to zwiększyć wykorzystanie pamięci o 1,5 raza (wagi będą wczytywane dwukrotnie z dokładnością do połowy i precyzji).

W przypadku wnioskowania połowa precyzji (keras.config.set_floatx("bfloat16")) będzie działać i oszczędzać pamięć, a mieszana precyzja nie jest dostępna.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Aby wczytać model z wagami i tensorami rozłożonymi na TPU, najpierw utwórz nowy element DeviceMesh. DeviceMesh to zbiór urządzeń skonfigurowanych pod kątem obliczeń rozproszonych i wprowadzono go w Keraście 3 w ramach ujednoliconego interfejsu API dystrybucji.

Rozkład API umożliwia korzystanie z równoległości danych i modeli, co pozwala na efektywne skalowanie modeli deep learning w wielu akceleratorach i hostach. Wykorzystuje bazową platformę (np. JAX) do dystrybucji programu i tensorów zgodnie z dyrektywami do fragmentacji, korzystając z procedury nazywanej „pojedynczym programem i rozszerzaniem wielu danych” (SPMD). Więcej informacji znajdziesz w nowym przewodniku po interfejsie API dystrybucji Keras 3.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

Pole LayoutMap z interfejsu Distribution API określa, w jaki sposób wagi i tensory powinny być dzielone lub replikowane za pomocą kluczy ciągu znaków, np. token_embedding/embeddings poniżej, które są traktowane jak wyrażenie regularne w celu dopasowania ścieżek tensora. Dopasowane tensory są dzielone na fragmenty z uwzględnieniem wymiarów modelu (8 TPU). a inne zostaną w pełni zreplikowane.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

Funkcja ModelParallel umożliwia fragmentowanie wag modelu lub tensorów aktywacji we wszystkich odchyleniach w DeviceMesh. W tym przypadku niektóre wagi modelu Gemma 7B są podzielone na 8 układów TPU zgodnie z layout_map określonym powyżej. Teraz wczytaj model w sposób rozproszony.

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

Teraz sprawdź, czy model został poprawnie podzielony na partycje. Dla przykładu przyjrzyjmy się decoder_block_1.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

Wnioskowanie przed dostrajeniem

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

Model generuje listę świetnych komedii z lat 90. ubiegłego wieku. Teraz dostrajamy model Gemma, aby zmienić styl wyjściowy.

Dostrajanie z IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

Dostrajanie za pomocą adaptacji niskiego rankingu (LoRA). LoRA to technika dostrajania, która znacznie zmniejsza liczbę parametrów możliwych do wytrenowania na kolejnych zadaniach przez zablokowanie pełnych ciężarów modelu i wstawienie do modelu mniejszej liczby nowych wag dostępnych do trenowania. Zasadniczo LoRA zmienia parametry większych macierzy o pełnej wagi, aby wytrenować 2 mniejsze macierze niskiego rangi AxB. Ta technika sprawia, że trenowanie jest znacznie szybsze i bardziej oszczędza pamięć.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

Pamiętaj, że włączenie LoRA znacznie zmniejsza liczbę parametrów z możliwością trenowania – z 7 mld do 11 milionów.

Wnioskowanie po dostrajaniu

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

Po dostrojeniu model nauczył się stylu recenzentów filmów i obecnie generuje go w kontekście komedii z lat 90.

Co dalej?

W tym samouczku pokazaliśmy Ci, jak przy użyciu backendu KerasNLP JAX dostrajać model Gemma w zbiorze danych IMDb w rozproszony sposób w wydajnych jednostkach TPU. Oto kilka sugestii, o których warto się jeszcze dowiedzieć: