ai.google.dev'de görüntüleyin | Google Colab'de çalıştır | Kaggle'da koş | Vertex AI'da aç | Kaynağı GitHub'da görüntüleyin |
Genel bakış
Gemma, Google Gemini modelleri oluşturmak için kullanılan araştırma ve teknolojiden yararlanarak hafif, modern ve açık modellerden oluşan bir ailedir. Gemma'ya özel ihtiyaçlara göre daha da ince ayar yapılabilir. Ancak Gemma gibi Büyük Dil Modellerinin boyutu çok büyük olabilir ve bunların bazıları ince ayar yapmak için bir şarkı hızlandırıcıya sığmayabilir. Bu durumda, ince ayar yapmak için iki genel yaklaşım vardır:
- Kaliteden ödün vererek etkili model boyutunu küçültmeyi amaçlayan Parametre Verimli Hassas Ayarlama (PEFT). LoRA bu kategoriye girer. LoRA kullanarak Keras'ta Gemma modellerinde ince ayar yapma eğiticisinde, tek bir GPU'da KerasNLP kullanarak
gemma_2b_en
Gemma 2B model için LoRA ile ince ayar yapma gösterilmektedir. - Model benzerliği ile tam parametre ince ayarı. Model paralelliği, tek bir modelin ağırlıklarını birden fazla cihaza dağıtır ve yatay ölçeklendirmeye olanak sağlar. Dağıtılmış eğitimler hakkında daha fazla bilgiyi bu Keras kılavuzunda bulabilirsiniz.
Bu eğitim, Google'ın Tensor İşleme Birimi (TPU) üzerinde LoRA ve model paralizm dağıtılmış eğitim ile Gemma 7B modelini hassaslaştırmak için JAX arka ucuyla Keras'ı kullanma konusunda size yol gösterir. Daha yavaş ancak daha doğru tam parametre ayarlaması için bu eğiticide LoRA'nın kapatılabileceğini unutmayın.
Hızlandırıcıları kullanma
Teknik olarak, bu eğitim için TPU veya GPU kullanabilirsiniz.
TPU ortamları hakkında notlar
Google'ın TPU sağlayan 3 ürünü vardır:
- Colab, TPU v2'yi ücretsiz olarak sunar. Bu da eğitim için yeterlidir.
- Kaggle, TPU v3'ü ücretsiz olarak sunar ve bu eğitimde de kullanılabilir.
- Cloud TPU, TPU v3 ve daha yeni nesiller sunar. Bunu aşağıdaki şekilde ayarlayabilirsiniz:
- Yeni bir TPU sanal makinesi oluşturun
- Kullanmak istediğiniz Jupyter sunucu bağlantı noktası için SSH bağlantı noktası yönlendirmeyi ayarlayın
- Jupyter'i yükleyip TPU sanal makinesinde başlatın, ardından "Yerel bir çalışma zamanına bağlan" aracılığıyla Colab'a bağlanın
Çoklu GPU kurulumuyla ilgili notlar
Bu eğitim TPU kullanım alanına odaklansa da, birden çok GPU'lu bir makineniz varsa bunu kendi ihtiyaçlarınıza göre kolayca uyarlayabilirsiniz.
Colab üzerinden çalışmayı tercih ederseniz Colab için doğrudan Colab Connect menüsündeki "Özel bir GCE sanal makinesine bağlan " üzerinden birden fazla GPU'lu sanal makine de sağlayabilirsiniz.
Burada, Kaggle'ın sunduğu ücretsiz TPU'yu kullanmaya odaklanacağız.
Başlamadan önce
Kaggle kimlik bilgileri
Gemma modelleri Kaggle tarafından barındırılır. Gemma'yı kullanmak için Kaggle'da erişim isteyin:
- kaggle.com adresinde oturum açın veya kaydolun
- Gemma modeli kartını açın ve "Erişim İste"yi seçin
- İzin formunu doldurup şartlar ve koşulları kabul edin
Ardından, Kaggle API'yi kullanmak için bir API jetonu oluşturun:
- Kaggle ayarlarını açın.
- "Yeni Jeton Oluştur"u seçin
- Bir
kaggle.json
dosyası indirildi. Kaggle kimlik bilgilerinizi içerir
Aşağıdaki hücreyi çalıştırın ve istendiğinde Kaggle kimlik bilgilerinizi girin.
# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub
kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…
kagglehub.login() işe yaramazsa alternatif bir yöntem de ortamınızda KAGGLE_USERNAME ve KAGGLE_KEY'i ayarlamaktır.
Döşeme
Gemma modeliyle Keras ve KerasNLP'yi yükleyin.
pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras
Keras JAX arka ucunu kurun
JAX'i içe aktarın ve TPU'da sağlık kontrolü çalıştırın. Kaggle, her biri 16 GB belleğe sahip 8 TPU çekirdeğine sahip TPUv3-8 cihazlar sunar.
import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os
# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
Modeli yükle
import keras
import keras_nlp
NVIDIA GPU'larla ilgili karma hassasiyetli eğitim hakkında notlar
NVIDIA GPU'larla eğitim yapılırken karma hassasiyet (keras.mixed_precision.set_global_policy('mixed_bfloat16')
) kullanılarak eğitimi hızlandırabilir ve eğitim kalitesini minimum düzeyde etkileyebilirsiniz. Çoğu durumda, hem bellek hem de zaman tasarrufu sağladığı için karma hassasiyetin etkinleştirilmesi önerilir. Bununla birlikte, küçük toplu iş boyutlarında bellek kullanımını 1,5 kat artırabileceğini unutmayın (ağırlıklar yarı hassasiyet ve tam hassasiyetle iki kez yüklenir).
Çıkarım için yarı duyarlık (keras.config.set_floatx("bfloat16")
) çalışır ve karma kesinlik geçerli olmadığında bellekten tasarruf eder.
# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')
Modeli TPU'lara dağıtılmış ağırlıklar ve tensörlerle yüklemek için önce yeni bir DeviceMesh
oluşturun. DeviceMesh
, dağıtılmış hesaplama için yapılandırılmış bir donanım cihazları koleksiyonunu temsil eder ve birleştirilmiş dağıtım API'sinin bir parçası olarak Keras 3'te kullanıma sunulmuştur.
Dağıtım API'si, derin öğrenme modellerinin birden fazla hızlandırıcıda ve ana makinede verimli bir şekilde ölçeklendirilmesine olanak tanıyarak veri ve model benzerliği sağlar. Tek programlı, çoklu veri (SPMD) genişletmesi adı verilen bir prosedür aracılığıyla programı ve tensörleri parçalama yönergelerine göre dağıtmak için temel çerçeveden (ör. JAX) yararlanır. Daha fazla bilgiyi yeni Keras 3 Distribution API kılavuzunda bulabilirsiniz.
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
(1, 8),
["batch", "model"],
devices=keras.distribution.list_devices())
Dağıtım API'sindeki LayoutMap
, dize anahtarları (ör. token_embedding/embeddings
) kullanılarak ağırlıkların ve tensörlerin nasıl parçalanması veya çoğaltılması gerektiğini belirtir. Bu anahtarlar, tensör yollarını eşleştirmek için normal ifade olarak değerlendirilir. Eşleşen tensörler, model boyutlarıyla (8 TPU) parçalanır; diğerleri tamamen çoğaltılır.
model_dim = "model"
layout_map = keras.distribution.LayoutMap(device_mesh)
# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
model_dim, None, None)
layout_map["decoder_block.*attention_output.*kernel"] = (
model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)
ModelParallel
, DeviceMesh
üzerindeki tüm cihazlarda model ağırlıklarını veya etkinleştirme tensörlerini parçalamanıza olanak tanır. Bu durumda, Gemma 7B model ağırlıklarının bazıları yukarıda tanımlanan layout_map
'ye göre 8 TPU çipinde parçalanır. Şimdi modeli dağıtılmış şekilde yükleyin.
model_parallel = keras.distribution.ModelParallel(
device_mesh, layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
Şimdi modelin doğru şekilde bölümlendiğini doğrulayın. Örnek olarak decoder_block_1
inceleyelim.
decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'> decoder_block_1/pre_attention_norm/scale (3072,) PartitionSpec(None,) decoder_block_1/attention/query/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/key/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/value/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/attention_output/kernel (16, 256, 3072) PartitionSpec(None, None, 'model') decoder_block_1/pre_ffw_norm/scale (3072,) PartitionSpec(None,) decoder_block_1/ffw_gating/kernel (3072, 24576) PartitionSpec('model', None) decoder_block_1/ffw_gating_2/kernel (3072, 24576) PartitionSpec('model', None) decoder_block_1/ffw_linear/kernel (24576, 3072) PartitionSpec(None, 'model')
İnce ayardan önce çıkarım
gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'
Model, 90'lardan izlenebilecek en iyi komedi filmlerinin bir listesini oluşturuyor. Şimdi, çıktı stilini değiştirmek için Gemma modeline ince ayar yapıyoruz.
IMDB ile ince ayar
import tensorflow_datasets as tfds
imdb_train = tfds.load(
"imdb_reviews",
split="train",
as_supervised=True,
batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)
imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0... Dl Completed...: 0 url [00:00, ? url/s] Dl Size...: 0 MiB [00:00, ? MiB/s] Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s] Generating train examples...: 0%| | 0/25000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord… Generating test examples...: 0%| | 0/25000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*… Generating unsupervised examples...: 0%| | 0/50000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t… Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data. b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)
Düşük Sıralama Uyarlaması'nı (LoRA) kullanarak ince ayar yapın. LoRA, modelin tam ağırlıklarını dondurarak ve modele daha az sayıda yeni eğitilebilir ağırlık ekleyerek aşağı akış görevleri için eğitilebilir parametre sayısını büyük ölçüde azaltan bir ince ayar tekniğidir. LoRA temelde, daha büyük tam ağırlık matrislerini, eğitilmesi için daha küçük olan 2 düşük sıralı matrisle (AxB) yeniden parametreleştirir. Bu teknik, eğitimi çok daha hızlı ve bellek açısından daha verimli hale getirir.
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.
# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" 2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329 <keras.src.callbacks.history.History at 0x7e9cac7f41c0>
LoRA'yı etkinleştirmenin, eğitilebilir parametre sayısını 7 milyardan 11 milyona önemli ölçüde düşürdüğünü unutmayın.
İnce ayardan sonra çıkarım
gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."
Ayarlamaların ardından model, film eleştirilerinin tarzını öğrendi ve şimdi 90'ların komedi filmleri bağlamında bu tarzda çıktılar üretiyor.
Sırada ne var?
Bu eğiticide, IMDb veri kümesindeki bir Gemma modeline güçlü TPU'larda dağıtılmış bir şekilde ince ayar yapmak için KerasNLP JAX arka ucunun nasıl kullanılacağını öğrendiniz. Öğrenebileceğiniz diğer konularla ilgili birkaç öneri:
- Keras Gemma'yı kullanmaya nasıl başlayacağınızı öğrenin.
- GPU'da Gemma modeline ince ayar yapmayı öğrenin.