Lihat di ai.google.dev | Jalankan di Google Colab | Berjalan di Kaggle | Buka di Vertex AI | Lihat sumber di GitHub |
Ringkasan
Gemma adalah sekumpulan model terbuka yang ringan dan canggih, dibangun dari riset dan teknologi yang digunakan untuk membuat model Google Gemini. Gemma dapat disesuaikan lebih lanjut agar sesuai dengan kebutuhan tertentu. Namun, Model Bahasa Besar, seperti Gemma, dapat berukuran sangat besar dan beberapa di antaranya mungkin tidak muat di akselerator sing untuk melakukan penyesuaian. Dalam kasus ini, ada dua pendekatan umum untuk menyesuaikannya:
- Parameter Efficient Fine-Tuning (PEFT), yang berupaya mengecilkan ukuran model yang efektif dengan mengorbankan sejumlah fidelitas. LoRA termasuk dalam kategori ini dan tutorial Menyesuaikan model Gemma di Keras menggunakan LoRA menunjukkan cara menyesuaikan model Gemma 2B
gemma_2b_en
dengan LoRA menggunakan KerasNLP di satu GPU. - Penyesuaian parameter lengkap dengan paralelisme model. Paralelisme model mendistribusikan bobot satu model di beberapa perangkat dan memungkinkan penskalaan horizontal. Anda dapat mengetahui lebih lanjut pelatihan terdistribusi dalam Panduan Keras ini.
Tutorial ini memandu Anda menggunakan Keras dengan backend JAX untuk meningkatkan kualitas model Gemma 7B dengan LoRA dan pelatihan terdistribusi paralelisme model di Tensor Processing Unit (TPU) Google. Perhatikan bahwa LoRA dapat dinonaktifkan dalam tutorial ini untuk penyesuaian parameter lengkap yang lebih lambat, tetapi lebih akurat.
Menggunakan akselerator
Secara teknis, Anda dapat menggunakan TPU atau GPU untuk tutorial ini.
Catatan tentang lingkungan TPU
Google memiliki 3 produk yang menyediakan TPU:
- Colab menyediakan TPU v2 secara gratis, yang cukup untuk tutorial ini.
- Kaggle menawarkan TPU v3 secara gratis dan juga berfungsi untuk tutorial ini.
- Cloud TPU menawarkan TPU v3 dan generasi yang lebih baru. Salah satu cara menyiapkannya adalah:
- Membuat VM TPU baru
- Menyiapkan penerusan port SSH untuk port server Jupyter yang diinginkan
- Instal Jupyter dan mulai di VM TPU, lalu hubungkan ke Colab melalui "Hubungkan ke runtime lokal"
Catatan tentang penyiapan multi-GPU
Meskipun tutorial ini berfokus pada kasus penggunaan TPU, Anda dapat dengan mudah menyesuaikannya untuk kebutuhan Anda sendiri jika memiliki mesin multi-GPU.
Jika Anda lebih suka bekerja melalui Colab, Anda juga dapat menyediakan VM multi-GPU untuk Colab secara langsung melalui "Hubungkan ke VM GCE kustom" di menu Colab Connect.
Kita akan berfokus pada penggunaan TPU gratis dari Kaggle di sini.
Sebelum memulai
Kredensial Kaggle
Model Gemma dihosting oleh Kaggle. Untuk menggunakan Gemma, minta akses di Kaggle:
- Login atau daftar di kaggle.com
- Buka kartu model Gemma dan pilih "Minta Akses"
- Lengkapi formulir izin dan setujui persyaratan dan ketentuan
Kemudian, untuk menggunakan Kaggle API, buat token API:
- Buka Kaggle settings
- Pilih "Buat Token Baru"
- File
kaggle.json
akan didownload. Berisi kredensial Kaggle Anda
Jalankan sel berikut dan masukkan kredensial Kaggle Anda saat diminta.
# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub
kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…
Cara alternatifnya adalah dengan menetapkan KAGGLE_USERNAME dan KAGGLE_KEY di lingkungan Anda jika kagglehub.login() tidak berfungsi untuk Anda.
Penginstalan
Instal Keras dan KerasNLP dengan model Gemma.
pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras
Menyiapkan backend Keras JAX
Impor JAX dan jalankan pemeriksaan keandalan di TPU. Kaggle menawarkan perangkat TPUv3-8 yang memiliki 8 inti TPU dengan masing-masing memori 16GB.
import jax
jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os
# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
Memuat model
import keras
import keras_nlp
Catatan tentang pelatihan presisi campuran pada GPU NVIDIA
Saat melatih GPU NVIDIA, presisi campuran (keras.mixed_precision.set_global_policy('mixed_bfloat16')
) dapat digunakan untuk mempercepat pelatihan dengan dampak minimal pada kualitas pelatihan. Pada umumnya, sebaiknya aktifkan presisi campuran karena menghemat memori dan waktu. Namun, perlu diketahui bahwa pada ukuran batch yang kecil, hal ini dapat meningkatkan penggunaan memori sebesar 1,5x (bobot akan dimuat dua kali, pada presisi setengah dan presisi penuh).
Untuk inferensi, presisi setengah (keras.config.set_floatx("bfloat16")
) akan berfungsi dan menghemat memori, sedangkan presisi campuran tidak berlaku.
# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')
Untuk memuat model dengan bobot dan tensor yang didistribusikan di seluruh TPU, buat DeviceMesh
baru terlebih dahulu. DeviceMesh
mewakili kumpulan perangkat hardware yang dikonfigurasi untuk komputasi terdistribusi dan diperkenalkan di Keras 3 sebagai bagian dari API distribusi terpadu.
API distribusi memungkinkan paralelisme data dan model, sehingga penskalaan model deep learning secara efisien di berbagai akselerator dan host. Library ini memanfaatkan framework yang mendasarinya (misalnya JAX) untuk mendistribusikan program dan tensor sesuai dengan perintah sharding melalui prosedur yang disebut perluasan single program, multiple data (SPMD). Lihat detail selengkapnya di panduan Keras 3 distribution API yang baru.
# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
(1, 8),
["batch", "model"],
devices=keras.distribution.list_devices())
LayoutMap
dari API distribusi menentukan cara bobot dan tensor harus di-shard atau direplikasi, menggunakan kunci string, misalnya, token_embedding/embeddings
di bawah, yang diperlakukan seperti ekspresi reguler untuk mencocokkan jalur tensor. Tensor yang cocok di-shard dengan dimensi model (8 TPU); yang lainnya akan direplikasi sepenuhnya.
model_dim = "model"
layout_map = keras.distribution.LayoutMap(device_mesh)
# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
model_dim, None, None)
layout_map["decoder_block.*attention_output.*kernel"] = (
model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)
ModelParallel
memungkinkan Anda melakukan shard bobot model atau tensor aktivasi di semua perangkat di DeviceMesh
. Dalam hal ini, beberapa bobot model Gemma 7B di-shard di 8 chip TPU sesuai dengan layout_map
yang ditentukan di atas. Sekarang, muat model dengan cara terdistribusi.
model_parallel = keras.distribution.ModelParallel(
layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook... normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
Sekarang, pastikan model telah dipartisi dengan benar. Mari kita ambil decoder_block_1
sebagai contoh.
decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'> decoder_block_1/pre_attention_norm/scale (3072,) PartitionSpec(None,) decoder_block_1/attention/query/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/key/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/value/kernel (16, 3072, 256) PartitionSpec(None, 'model', None) decoder_block_1/attention/attention_output/kernel (16, 256, 3072) PartitionSpec(None, None, 'model') decoder_block_1/pre_ffw_norm/scale (3072,) PartitionSpec(None,) decoder_block_1/ffw_gating/kernel (3072, 24576) PartitionSpec('model', None) decoder_block_1/ffw_gating_2/kernel (3072, 24576) PartitionSpec('model', None) decoder_block_1/ffw_linear/kernel (24576, 3072) PartitionSpec(None, 'model')
Inferen sebelum penyesuaian
gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'
Model ini menghasilkan daftar film komedi terbaik dari tahun 90-an untuk ditonton. Sekarang kita akan menyesuaikan model Gemma untuk mengubah gaya output.
Mengoptimalkan dengan IMDB
import tensorflow_datasets as tfds
imdb_train = tfds.load(
"imdb_reviews",
split="train",
as_supervised=True,
batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)
imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0... Dl Completed...: 0 url [00:00, ? url/s] Dl Size...: 0 MiB [00:00, ? MiB/s] Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s] Generating train examples...: 0%| | 0/25000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord… Generating test examples...: 0%| | 0/25000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*… Generating unsupervised examples...: 0%| | 0/50000 [00:00<?, ? examples/s] Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t… Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data. b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)
Lakukan penyesuaian menggunakan Low Rank Adaptation (LoRA). LoRA adalah teknik fine-tuning yang sangat mengurangi jumlah parameter yang dapat dilatih untuk tugas downstream dengan membekukan bobot penuh model dan memasukkan sejumlah kecil bobot baru yang dapat dilatih ke dalam model. Pada dasarnya, LoRA mereparameterisasi matriks bobot penuh yang lebih besar dengan 2 matriks AxB peringkat rendah yang lebih kecil untuk dilatih dan teknik ini membuat pelatihan jauh lebih cepat dan lebih hemat memori.
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.
# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" 2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329 <keras.src.callbacks.history.History at 0x7e9cac7f41c0>
Perhatikan bahwa mengaktifkan LoRA akan mengurangi jumlah parameter yang dapat dilatih secara signifikan, dari 7 miliar menjadi hanya 11 juta.
Inferensia setelah penyesuaian
gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."
Setelah melakukan penyesuaian, model telah mempelajari gaya ulasan film dan kini menghasilkan output dalam gaya tersebut dalam konteks film komedi tahun 90-an.
Langkah berikutnya
Dalam tutorial ini, Anda telah mempelajari cara menggunakan backend KerasNLP JAX untuk meningkatkan kualitas model Gemma pada set data IMDb secara terdistribusi pada TPU yang andal. Berikut adalah beberapa saran tentang hal lain yang perlu dipelajari:
- Pelajari cara memulai Keras Gemma.
- Pelajari cara menyetel model Gemma di GPU.