Réglage distribué avec Gemma à l'aide de Keras

Voir sur ai.google.dev Exécuter dans Kaggle Ouvrir dans Vertex AI Afficher la source sur GitHub

Présentation

Gemma est une famille de modèles ouverts légers et ultramodernes, conçus à partir de la recherche et des technologies utilisées pour créer des modèles Google Gemini. Vous pouvez affiner davantage la Gemma en fonction de vos besoins spécifiques. Toutefois, les grands modèles de langage, tels que Gemma, peuvent être de très grande taille, et certains d'entre eux risquent de ne pas pouvoir être optimisés pour le réglage. Dans ce cas, il existe deux approches générales pour les affiner:

  1. L'amélioration efficace des paramètres (PEFT), qui cherche à réduire la taille effective du modèle en sacrifiant une certaine fidélité. LoRA appartient à cette catégorie. Le tutoriel Ajuster les modèles Gemma dans Keras à l'aide de LoRA montre comment affiner le modèle Gemma 2B gemma_2b_en avec LoRA en utilisant KerasNLP sur un seul GPU.
  2. Affinage complet des paramètres avec le parallélisme des modèles. Le parallélisme des modèles répartit les pondérations d'un modèle unique sur plusieurs appareils et permet un scaling horizontal. Pour en savoir plus sur l'entraînement distribué, consultez ce guide Keras.

Ce tutoriel explique comment utiliser Keras avec un backend JAX pour affiner le modèle Gemma 7B avec une LoRA et l'entraînement distribué du parallélisme de modèle sur le Tensor Processing Unit (TPU) de Google. Notez que la fonctionnalité LoRA peut être désactivée dans ce tutoriel pour un réglage complet plus lent, mais plus précis.

Utiliser des accélérateurs

Techniquement, vous pouvez utiliser le TPU ou le GPU pour ce tutoriel.

Remarques sur les environnements TPU

Google propose trois produits qui fournissent des TPU:

  • Colab fournit TPU v2, ce qui n'est pas suffisant pour ce tutoriel.
  • Kaggle propose des TPU v3 sans frais, qui sont compatibles avec ce tutoriel.
  • Cloud TPU propose TPU v3 et les générations plus récentes. Voici une façon de le configurer :
    1. Créer une VM TPU
    2. Configurez le transfert de port SSH pour le port de serveur Jupyter prévu.
    3. Installez Jupyter et démarrez-le sur la VM TPU, puis connectez-vous à Colab via l'option "Se connecter à un environnement d'exécution local".

Remarques sur la configuration multiGPU

Bien que ce tutoriel se concentre sur le cas d'utilisation des TPU, vous pouvez facilement l'adapter à vos propres besoins si vous disposez d'une machine multi-GPU.

Si vous préférez utiliser Colab, vous pouvez également provisionner une VM multi-GPU pour Colab directement via l'option "Se connecter à une VM GCE personnalisée" dans le menu Colab Connect.

Nous allons nous concentrer ici sur l'utilisation du TPU sans frais de Kaggle.

Avant de commencer

Identifiants Kaggle

Les modèles Gemma sont hébergés par Kaggle. Pour utiliser Gemma, demandez l'accès sur Kaggle:

  • Connectez-vous ou inscrivez-vous sur kaggle.com.
  • Ouvrez la fiche du modèle Gemma, puis sélectionnez Demander l'accès.
  • Remplissez le formulaire de consentement et acceptez les conditions d'utilisation

Ensuite, pour utiliser l'API Kaggle, créez un jeton d'API:

  • Ouvrez les paramètres Kaggle.
  • Sélectionnez Create New Token (Créer un jeton).
  • Un fichier kaggle.json est téléchargé. Il contient vos informations d'identification Kaggle

Exécutez la cellule suivante et saisissez vos informations d'identification Kaggle lorsque vous y êtes invité.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Vous pouvez également définir KAGGLE_USERNAME et KAGGLE_KEY dans votre environnement si kagglehub.login() ne fonctionne pas.

Installation

Installer Keras et KerasNLP avec le modèle Gemma

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Configurer le backend Keras JAX

Importez JAX et effectuez un contrôle d'intégrité sur le TPU. Kaggle propose des appareils TPUv3-8 dotés de 8 cœurs TPU dotés chacun de 16 Go de mémoire.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Charger le modèle

import keras
import keras_nlp

Remarques sur l'entraînement de précision mixte sur les GPU NVIDIA

Lorsque vous effectuez un entraînement sur des GPU NVIDIA, vous pouvez utiliser la précision mixte (keras.mixed_precision.set_global_policy('mixed_bfloat16')) pour accélérer l'entraînement avec un impact minimal sur sa qualité. Dans la plupart des cas, nous vous recommandons d'activer la précision mixte, car elle permet d'économiser de la mémoire et du temps. Toutefois, sachez qu'en cas de petite taille de lot, l'utilisation de la mémoire peut être multipliée par 1,5 (les pondérations sont chargées deux fois, avec une demi-précision et une précision maximale).

Pour l'inférence, la demi-précision (keras.config.set_floatx("bfloat16")) fonctionne et permet d'économiser de la mémoire, tandis que la précision mixte n'est pas applicable.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Pour charger le modèle avec des pondérations et des Tensors répartis entre les TPU, vous devez d'abord créer un objet DeviceMesh. DeviceMesh représente un ensemble d'appareils configurés pour le calcul distribué. Il a été introduit dans Keras 3, dans le cadre de l'API de distribution unifiée.

L'API de distribution permet le parallélisme des données et des modèles, ce qui permet un scaling efficace des modèles de deep learning sur plusieurs accélérateurs et hôtes. Il exploite le framework sous-jacent (par exemple, JAX) pour distribuer le programme et les Tensors selon les directives de segmentation, via une procédure appelée expansion SPMD (Single Program, Multiple Data). Pour en savoir plus, consultez le nouveau guide de l'API de distribution Keras 3.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

LayoutMap de l'API de distribution spécifie la manière dont les pondérations et les Tensors doivent être segmentés ou répliqués à l'aide des clés de chaîne (par exemple, token_embedding/embeddings ci-dessous), qui sont traitées comme des expressions régulières pour correspondre aux chemins des Tensors. Les Tensors correspondants sont segmentés en fonction des dimensions du modèle (8 TPU), tandis que les autres sont entièrement répliqués.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel vous permet de segmenter les pondérations de modèle ou les Tensors d'activation sur toutes les plates-formes du DeviceMesh. Dans ce cas, certaines pondérations du modèle Gemma 7B sont réparties sur 8 puces TPU selon l'élément layout_map défini ci-dessus. Chargez maintenant le modèle de manière distribuée.

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

Vérifiez maintenant que le modèle a été partitionné correctement. Prenons decoder_block_1 comme exemple.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

Inférence avant l'affinage

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

Le modèle génère une liste de grands films de comédie des années 90 à regarder. Nous allons maintenant affiner le modèle Gemma pour modifier le style de sortie.

Finetune avec IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

Effectuez le réglage à l'aide de l'adaptation de rang faible (LoRA). LoRA est une technique d'ajustement qui réduit considérablement le nombre de paramètres pouvant être entraînés pour les tâches en aval en gelant la totalité des pondérations du modèle et en y insérant un plus petit nombre de nouvelles pondérations pouvant être entraînées. En gros, la fonction LoRA reparamètre les matrices de pondération complète les plus importantes par 2 matrices de rang faible plus petites (AxB) pour l'entraînement. Cette technique rend l'entraînement beaucoup plus rapide et plus efficace en termes de mémoire.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

Notez que l'activation de la fonctionnalité LoRA permet de réduire considérablement le nombre de paramètres pouvant être entraînés, passant de 7 milliards à seulement 11 millions.

Inférence après réglage

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

Après affinage, le modèle a appris le style des critiques de films et génère maintenant une sortie dans ce style dans le contexte des comédies des années 90.

Étapes suivantes

Dans ce tutoriel, vous avez appris à utiliser le backend KerasNLP JAX pour affiner un modèle Gemma sur l'ensemble de données IMDb de manière distribuée sur des TPU puissants. Voici quelques suggestions d'informations: