Réglage distribué avec Gemma à l'aide de Keras

Afficher sur ai.google.dev Exécuter dans Google Colab Exécuter dans Kaggle Ouvrir dans Vertex AI Consulter le code source sur GitHub

Présentation

Gemma est une famille de modèles ouverts légers et de pointe, élaborés à partir des recherches et des technologies utilisées pour créer des modèles Google Gemini. Gemma peut être affiné davantage pour répondre à des besoins spécifiques. Toutefois, les grands modèles de langage, tels que Gemma, peuvent être très volumineux, et certains d'entre eux ne peuvent pas être optimisés sur un seul accélérateur. Dans ce cas, il existe deux approches générales pour les affiner:

  1. L'affinage avec optimisation des paramètres (PEFT, Parameter Efficient Fine-Tuning), qui vise à réduire la taille effective du modèle en sacrifiant une partie de la fidélité. LoRA fait partie de cette catégorie. Le tutoriel Affiner les modèles Gemma dans Keras à l'aide de LoRA explique comment affiner le modèle Gemma 2B gemma_2b_en avec LoRA à l'aide de KerasNLP sur un seul GPU.
  2. Affinage complet des paramètres avec parallélisme des modèles. Le parallélisme des modèles répartit les poids d'un modèle unique sur plusieurs appareils et permet la mise à l'échelle horizontale. Pour en savoir plus sur l'entraînement distribué, consultez ce guide Keras.

Ce tutoriel vous explique comment utiliser Keras avec un backend JAX pour affiner le modèle Gemma 7B avec LoRA et l'entraînement distribué par parallélisme de modèle sur le TPU (Tensor Processing Unit) de Google. Notez que LoRA peut être désactivé dans ce tutoriel pour un réglage complet des paramètres plus lent, mais plus précis.

Utiliser des accélérateurs

Techniquement, vous pouvez utiliser un TPU ou un GPU pour ce tutoriel.

Remarques sur les environnements TPU

Google propose trois produits qui fournissent des TPU:

  • Colab fournit sans frais TPU v2, ce qui est suffisant pour ce tutoriel.
  • Kaggle propose des TPU v3 sans frais, qui fonctionnent également pour ce tutoriel.
  • Cloud TPU propose les TPU v3 et les générations ultérieures. Vous pouvez le configurer de la manière suivante :
    1. Créer une VM TPU
    2. Configurez le transfert de port SSH pour le port du serveur Jupyter souhaité.
    3. Installez Jupyter et démarrez-le sur la VM TPU, puis connectez-vous à Colab via "Connecter à un environnement d'exécution local".

Remarques sur la configuration multi-GPU

Bien que ce tutoriel se concentre sur le cas d'utilisation des TPU, vous pouvez facilement l'adapter à vos propres besoins si vous disposez d'une machine multi-GPU.

Si vous préférez travailler via Colab, vous pouvez également provisionner une VM multi-GPU pour Colab directement via "Se connecter à une VM GCE personnalisée" dans le menu Colab Connect.

Nous allons nous concentrer sur l'utilisation du TPU sans frais de Kaggle.

Avant de commencer

Identifiants Kaggle

Les modèles Gemma sont hébergés par Kaggle. Pour utiliser Gemma, demandez l'accès sur Kaggle:

  • Connectez-vous ou inscrivez-vous sur kaggle.com.
  • Ouvrez la fiche de modèle Gemma, puis sélectionnez "Demander l'accès".
  • Remplir le formulaire d'autorisation et accepter les conditions d'utilisation

Ensuite, pour utiliser l'API Kaggle, créez un jeton d'API:

  • Ouvrez les paramètres Kaggle.
  • Sélectionnez Create New Token (Créer un jeton).
  • Un fichier kaggle.json est téléchargé. Il contient vos identifiants Kaggle

Exécutez la cellule suivante et saisissez vos identifiants Kaggle lorsque vous y êtes invité.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Une autre méthode consiste à définir KAGGLE_USERNAME et KAGGLE_KEY dans votre environnement si kagglehub.login() ne fonctionne pas pour vous.

Installation

Installez Keras et KerasNLP avec le modèle Gemma.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Configurer le backend Keras JAX

Importez JAX et effectuez une évaluation de l'intégrité sur TPU. Kaggle propose des appareils TPUv3-8 dotés de huit cœurs TPU avec 16 Go de mémoire chacun.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Charger le modèle

import keras
import keras_nlp

Remarques sur l'entraînement de précision mixte sur les GPU NVIDIA

Lors de l'entraînement sur des GPU NVIDIA, la précision mixte (keras.mixed_precision.set_global_policy('mixed_bfloat16')) peut être utilisée pour accélérer l'entraînement avec un impact minimal sur la qualité de l'entraînement. Dans la plupart des cas, nous vous recommandons d'activer la précision mixte, car elle permet d'économiser de la mémoire et du temps. Toutefois, sachez qu'à de petites tailles de lot, il peut gonfler l'utilisation de la mémoire de 1,5 fois (les poids seront chargés deux fois, à demi-précision et à pleine précision).

Pour l'inférence, la demi-précision (keras.config.set_floatx("bfloat16")) fonctionne et économise de la mémoire, tandis que la précision mixte n'est pas applicable.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Pour charger le modèle avec les poids et les tenseurs distribués sur les TPU, créez d'abord un DeviceMesh. DeviceMesh représente un ensemble d'appareils matériels configurés pour le calcul distribué. Il a été introduit dans Keras 3 dans le cadre de l'API de distribution unifiée.

L'API de distribution permet le parallélisme des données et des modèles, ce qui permet de faire évoluer efficacement les modèles de deep learning sur plusieurs accélérateurs et hôtes. Il s'appuie sur le framework sous-jacent (par exemple, JAX) pour distribuer le programme et les tenseurs en fonction des directives de partitionnement via une procédure appelée "expansion SPMD" (single program, multiple data). Pour en savoir plus, consultez le nouveau guide de l'API de distribution Keras 3.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

LayoutMap de l'API de distribution spécifie comment les pondérations et les Tensors doivent être segmentés ou répliqués à l'aide des clés de chaîne (par exemple, token_embedding/embeddings ci-dessous), qui sont traitées comme des expressions régulières pour correspondre aux chemins d'accès des Tensors. Les Tensors mis en correspondance sont segmentés avec les dimensions du modèle (8 TPU). Les autres Tensors sont entièrement répliqués.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel vous permet de diviser les poids du modèle ou les tenseurs d'activation sur tous les appareils du DeviceMesh. Dans ce cas, certains des poids du modèle Gemma 7B sont répartis sur huit puces TPU, conformément à la layout_map définie ci-dessus. Chargez maintenant le modèle de manière distribuée.

model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

Vérifiez maintenant que le modèle a été partitionné correctement. Prenons decoder_block_1 comme exemple.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

Inférence avant l'ajustement

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

Le modèle génère une liste de grands films comiques des années 90 à regarder. Nous allons maintenant affiner le modèle Gemma pour modifier le style de sortie.

Finaliser avec IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

Effectuez des réglages à l'aide de la fonctionnalité LoRA (Low Rank Adaptation). LoRA est une technique d'affinage qui réduit considérablement le nombre de paramètres enregistrables pour les tâches en aval en congelant l'ensemble des poids du modèle et en insérant un plus petit nombre de nouveaux poids enregistrables dans le modèle. En gros, LoRA reparamètre les matrices de poids complètes plus grandes par deux matrices AxB de rang faible plus petites pour l'entraînement. Cette technique rend l'entraînement beaucoup plus rapide et plus économe en mémoire.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

Notez que l'activation de la LoRA réduit considérablement le nombre de paramètres pouvant être entraînés, passant de 7 milliards à seulement 11 millions.

Inférence après ajustement

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

Après avoir été affiné, le modèle a appris le style des critiques de films et génère désormais des résultats dans ce style dans le contexte des comédies des années 90.

Étape suivante

Dans ce tutoriel, vous avez appris à utiliser le backend JAX de KerasNLP pour affiner un modèle Gemma sur l'ensemble de données IMDb de manière distribuée sur les puissants TPU. Voici quelques suggestions pour en savoir plus: