Keras kullanılarak Gemma ile dağıtılan ayarlama

ai.google.dev adresinde görüntüleme Google Colab'da çalıştırma Kaggle'da çalıştırma Vertex AI'da aç Kaynağı GitHub'da görüntüleyin

Genel Bakış

Gemma, Google Gemini modellerini oluşturmakta kullanılan araştırmalar ve teknolojiyle oluşturulmuş hafif, son teknoloji ürünü açık modeller ailesidir. Gemma, belirli ihtiyaçlara uyacak şekilde daha fazla ayarlanabilir. Ancak Gemma gibi büyük dil modelleri çok büyük olabilir ve bazılarının ince ayar için tek bir hızlandırıcıya sığmaması mümkündür. Bu durumda, ince ayar yapmak için iki genel yaklaşım vardır:

  1. Bazı doğruluklardan ödün vererek etkili model boyutunu küçültmeyi amaçlayan Parametreleri Verimli Kullanarak İnce Ayarlama (PEFT). LoRA bu kategoride yer alır. LoRA'yı kullanarak Keras'taki Gemma modellerinde ince ayar yapma başlıklı eğiticide, tek bir GPU'da KerasNLP kullanarak LoRA ile gemma_2b_en Gemma 2B modelinde nasıl ince ayar yapılacağı gösterilmektedir.
  2. Model paralelliğiyle tam parametre hassas ayarı. Model paralelliği, tek bir modelin ağırlıklarını birden fazla cihaza dağıtır ve yatay ölçeklendirmeyi etkinleştirir. Dağıtılmış eğitim hakkında daha fazla bilgiyi bu Keras kılavuzunda bulabilirsiniz.

Bu eğitimde, Google'ın Tensor İşleme Birimi'nde (TPU) LoRA ve model paralelliği dağıtılmış eğitimi ile Gemma 7B modelinde ince ayar yapmak için Keras'ı JAX arka ucuyla kullanma konusunda size yol gösterilmektedir. Daha yavaş ancak daha doğru tam parametre ayarı için bu eğitimde LoRA'nın devre dışı bırakılabileceğini unutmayın.

Hızlandırıcıları kullanma

Teknik olarak bu eğitim için TPU veya GPU kullanabilirsiniz.

TPU ortamlarıyla ilgili notlar

Google'ın TPU sağlayan 3 ürünü vardır:

  • Colab, bu eğitim için yeterli olan TPU v2'yi ücretsiz sunar.
  • Kaggle, TPU v3'ü ücretsiz olarak sunar ve bu eğitimde de kullanılabilir.
  • Cloud TPU, TPU v3 ve daha yeni nesilleri sunar. Bunu ayarlama yöntemlerinden biri:
    1. Yeni bir TPU sanal makinesi oluşturun
    2. Aktarmak istediğiniz Jupyter sunucusu bağlantı noktası için SSH bağlantı noktası yönlendirmeyi ayarlayın.
    3. Jupyter'ı yükleyip TPU sanal makinesinde başlatın, ardından "Yerel çalışma zamanına bağlan" seçeneğini kullanarak Colab'a bağlanın.

Çoklu GPU kurulumuyla ilgili notlar

Bu eğitimde TPU kullanım alanına odaklanılmış olsa da birden fazla GPU'ya sahip bir makineniz varsa bu eğitimi kendi ihtiyaçlarınıza göre kolayca uyarlayabilirsiniz.

Colab üzerinden çalışmayı tercih ediyorsanız Colab için doğrudan Colab Connect menüsündeki "Özel bir GCE sanal makinesine bağlan " seçeneğiyle çok GPU'lu bir sanal makine de sağlayabilirsiniz.

Burada Kaggle'daki ücretsiz TPU'yu kullanmaya odaklanacağız.

Başlamadan önce

Kaggle kimlik bilgileri

Gemma modelleri Kaggle tarafından barındırılır. Gemma'yı kullanmak için Kaggle'da erişim isteğinde bulunun:

  • kaggle.com adresinden oturum açın veya kaydolun.
  • Gemma model kartını açın ve "Erişim İste"'yi seçin.
  • İzin formunu doldurup şartlar ve koşulları kabul edin

Ardından, Kaggle API'yi kullanmak için bir API jetonu oluşturun:

  • Kaggle ayarları'nı açın.
  • "Yeni Jeton Oluştur"'u seçin.
  • kaggle.json dosyası indirilir. Kaggle kimlik bilgilerinizi içerir

Aşağıdaki hücreyi çalıştırın ve istendiğinde Kaggle kimlik bilgilerinizi girin.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

kagglehub.login() sizin için işe yaramazsa alternatif bir yöntem olarak ortamınızda KAGGLE_USERNAME ve KAGGLE_KEY ayarlayabilirsiniz.

Kurulum

Gemma modeliyle birlikte Keras ve KerasNLP'yi yükleyin.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Keras JAX arka ucunu ayarlama

JAX'i içe aktarın ve TPU'da bütünlük kontrolü çalıştırın. Kaggle, her biri 16 GB belleğe sahip 8 TPU çekirdeğine sahip TPUv3-8 cihazları sunar.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Modeli yükleme

import keras
import keras_nlp

NVIDIA GPU'larda karma hassasiyetli eğitimle ilgili notlar

NVIDIA GPU'larda eğitim yapılırken eğitim kalitesinde minimum etkiyle eğitimi hızlandırmak için karma hassasiyet (keras.mixed_precision.set_global_policy('mixed_bfloat16')) kullanılabilir. Çoğu durumda, hem bellek hem de zaman tasarrufu sağladığı için karma hassasiyetin etkinleştirilmesi önerilir. Ancak küçük toplu boyutlarda bellek kullanımının 1,5 kat artabileceğini unutmayın (ağırlıklar yarım hassasiyet ve tam hassasiyette iki kez yüklenir).

Karma hassasiyet geçerli olmadığından, çıkarım için yarı hassasiyet (keras.config.set_floatx("bfloat16")) kullanılabilir ve bellekte tasarruf sağlar.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Modeli, TPU'lara dağıtılan ağırlıklar ve tenzorlarla yüklemek için önce yeni bir DeviceMesh oluşturun. DeviceMesh, dağıtılmış hesaplama için yapılandırılmış bir donanım cihazları koleksiyonunu temsil eder ve birleşik dağıtım API'sinin bir parçası olarak Keras 3'te kullanıma sunulmuştur.

Dağıtım API'si, veri ve model paralelliğini etkinleştirerek derin öğrenme modellerinin birden fazla hızlandırıcı ve ana makinede verimli bir şekilde ölçeklendirilmesine olanak tanır. Tek program, çoklu veri (SPMD) genişletmesi adı verilen bir prosedür yoluyla programı ve tensörleri parçalama yönergelerine göre dağıtmak için temel çerçeveden (ör. JAX) yararlanır. Daha fazla bilgi için yeni Keras 3 dağıtım API kılavuzunu inceleyin.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

Dağıtım API'sindeki LayoutMap, ağırlıkların ve tenzorların, aşağıdaki token_embedding/embeddings gibi dize anahtarları kullanılarak nasıl bölüneceğini veya çoğaltılacağını belirtir. Bu anahtarlar, tenzor yollarıyla eşleşmek için normal ifade gibi işlenir. Eşleşen tensörler model boyutlarıyla (8 TPU), diğerleri tamamen çoğaltılır.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel, DeviceMesh üzerindeki tüm cihazlarda model ağırlıklarını veya aktivasyon tensörlerini kırmanıza olanak tanır. Bu durumda, Gemma 7B model ağırlıklarından bazıları yukarıda tanımlanan layout_map değerine göre 8 TPU çipi arasında bölünür. Şimdi modeli dağıtılmış şekilde yükleyin.

model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

Şimdi de modelin doğru şekilde bölümlendirildiğini doğrulayın. Örnek olarak decoder_block_1 öğesini ele alalım.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

İnce ayar yapmadan önce çıkarım

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

Model, 90'lı yıllarda çekilmiş ve izlenebilecek en iyi komedi filmlerinin listesini oluşturur. Şimdi, çıkış stilini değiştirmek için Gemma modelinde ince ayar yapıyoruz.

IMDb ile hassas ayarlama

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

Düşük Sıralama Uyarlaması'nı (LoRA) kullanarak hassas ayarlama yapın. LoRA, modelin tüm ağırlıklarını dondurarak ve modele daha az sayıda yeni eğitilebilir ağırlık ekleyerek aşağı akış görevleri için eğitilebilir parametrelerin sayısını büyük ölçüde azaltan bir ince ayar tekniğidir. Temel olarak LoRA, eğitmek için daha büyük tam ağırlık matrislerini 2 küçük düşük rütbeli AxB matrisiyle yeniden parametrelendirir ve bu teknik, eğitimi çok daha hızlı ve daha bellek verimli hale getirir.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

LoRA'yı etkinleştirmenin, eğitilebilir parametrelerin sayısını 7 milyardan yalnızca 11 milyona düşürdüğünü unutmayın.

İnce ayardan sonra çıkarım

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

İnce ayarlama yapıldıktan sonra model, film yorumlarının tarzını öğrenmiştir ve artık 90'ların komedi filmleri bağlamında bu tarzda sonuçlar üretmektedir.

Sırada ne var?

Bu eğiticide, IMDb veri kümesi üzerinde bir Gemma modeli üzerinde güçlü TPU'lar üzerinde dağıtılmış şekilde ince ayar yapmak için KerasNLP JAX arka ucunun nasıl kullanılacağını öğrendiniz. Öğrenebileceğiniz diğer konularla ilgili birkaç öneriyi aşağıda bulabilirsiniz: