Inference with Gemma using JAX and Flax


Overview

Gemma is a family of lightweight, state-of-the-art open large language models, built on the Gemini research and technology from Google DeepMind. This tutorial shows how to perform basic sampling/inference with the Gemma 2B Instruct model using Google DeepMind's gemma library, which was written with JAX (a high-performance numerical computing library), Flax (the JAX-based neural network library), Orbax (a JAX-based library for training utilities such as checkpointing), and SentencePiece (a tokenizer/detokenizer library). Although Flax is not used directly in this notebook, it was used to create Gemma.

This notebook can run on Google Colab with the free T4 GPU. Go to Edit > Notebook settings and select T4 GPU under Hardware accelerator.

Setup

1. Set up Kaggle access for Gemma

To complete this tutorial, you first need to follow the instructions on the Gemma setup page, which show you how to do the following:

  • Get access to Gemma on kaggle.com.
  • Select a Colab runtime with sufficient resources to run the Gemma model.
  • Generate and configure a Kaggle username and API key.

After you have completed the Gemma setup, move on to the next section, where you will set environment variables for your Colab environment.

2. Set environment variables

Set environment variables for KAGGLE_USERNAME and KAGGLE_KEY. When the "Grant access?" prompts appear, agree to provide access to the secrets.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
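As a quick sanity check before downloading anything, the sketch below reports which of the two Kaggle variables are still unset. The helper name `missing_kaggle_credentials` is hypothetical and is not part of Colab or kagglehub; it only inspects a mapping such as `os.environ`.

```python
import os
from typing import Mapping

# The two variables this tutorial sets from Colab secrets.
REQUIRED_VARS = ("KAGGLE_USERNAME", "KAGGLE_KEY")


def missing_kaggle_credentials(env: Mapping[str, str] = os.environ) -> list[str]:
    """Return the required variable names that are absent or empty in `env`."""
    return [name for name in REQUIRED_VARS if not env.get(name)]


missing = missing_kaggle_credentials()
if missing:
    print("Set these before downloading the model:", ", ".join(missing))
else:
    print("Kaggle credentials found.")
```

Passing a plain dictionary instead of `os.environ` makes the check easy to try without touching the real environment.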

3. Install the gemma library

This notebook focuses on using a free Colab GPU. To enable hardware acceleration, click Edit > Notebook settings > select T4 GPU > Save.

Next, install the Google DeepMind gemma library from github.com/google-deepmind/gemma. If you get an error about "pip's dependency resolver", you can usually ignore it.

pip install -q git+https://github.com/google-deepmind/gemma.git
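To confirm that the installation left the expected modules importable, a small check along these lines may help. It is only a sketch: `is_installed` is a hypothetical helper, and the module names are simply the ones imported later in this tutorial.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """True if a top-level module with this name can be found on the import path."""
    return importlib.util.find_spec(module_name) is not None


# Modules imported in the remaining cells of this tutorial.
for name in ("gemma", "kagglehub", "sentencepiece"):
    status = "available" if is_installed(name) else "missing"
    print(f"{name}: {status}")
```

The check looks up each module without importing it, so it is cheap to run right after the install cell.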

Load and prepare the Gemma model

  1. Load the Gemma model with kagglehub.model_download, which takes three arguments:
  • handle: the model handle from Kaggle
  • path: (optional string) the local path
  • force_download: (optional boolean) forces the model to be downloaded again
GEMMA_VARIANT = 'gemma2-2b-it' # @param ['gemma2-2b', 'gemma2-2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma-2/flax/{GEMMA_VARIANT}')
[Output truncated: download progress for 11 files from www.kaggle.com, including tokenizer.model and two large checkpoint data files of 2.12G and 1.70G.]
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma-2-2b/flax/gemma2-2b-it/1
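The handle passed to kagglehub.model_download above follows the pattern google/gemma-2/flax/<variant>. The sketch below builds that handle and forwards the optional force_download argument. The helper names `build_handle` and `download_variant` are hypothetical, and kagglehub is imported lazily so the handle logic can be tried without the library or a network connection.

```python
# The two options offered by the GEMMA_VARIANT form field in this tutorial.
KNOWN_VARIANTS = ("gemma2-2b", "gemma2-2b-it")


def build_handle(variant: str) -> str:
    """Build the Kaggle model handle used in this tutorial for a given variant."""
    if variant not in KNOWN_VARIANTS:
        raise ValueError(f"Unexpected variant {variant!r}; expected one of {KNOWN_VARIANTS}")
    return f"google/gemma-2/flax/{variant}"


def download_variant(variant: str, force: bool = False) -> str:
    """Download (or reuse the cached copy of) a variant and return its local path."""
    import kagglehub  # imported here so build_handle works without kagglehub installed

    return kagglehub.model_download(build_handle(variant), force_download=force)


print(build_handle("gemma2-2b-it"))  # google/gemma-2/flax/gemma2-2b-it
```

Calling `download_variant("gemma2-2b-it", force=True)` would request a fresh download instead of reusing the cached files.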
  2. Check the location of the model weights and the tokenizer, then set the path variables. The tokenizer file is in the main directory where you downloaded the model, while the model weights are in a subdirectory. For example:
  • The tokenizer.model file will be in /LOCAL/PATH/TO/gemma/flax/2b-it/2.
  • The model checkpoint will be in /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma-2-2b/flax/gemma2-2b-it/1/gemma2-2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma-2-2b/flax/gemma2-2b-it/1/tokenizer.model
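Before loading anything, it can be worth confirming that the downloaded directory has the layout described above: a tokenizer.model file at the top level and a checkpoint subdirectory named after the variant. The following is a minimal sketch; `check_gemma_layout` is a hypothetical helper, and the demonstration uses a temporary directory rather than a real download.

```python
import os
import tempfile


def check_gemma_layout(gemma_path: str, variant: str) -> dict[str, bool]:
    """Report whether the tokenizer file and the checkpoint directory exist."""
    return {
        "tokenizer": os.path.isfile(os.path.join(gemma_path, "tokenizer.model")),
        "checkpoint": os.path.isdir(os.path.join(gemma_path, variant)),
    }


# Demonstration on a throwaway directory that mimics the expected layout.
with tempfile.TemporaryDirectory() as root:
    os.mkdir(os.path.join(root, "gemma2-2b-it"))
    open(os.path.join(root, "tokenizer.model"), "wb").close()
    layout = check_gemma_layout(root, "gemma2-2b-it")

print(layout)
```

In the notebook, the equivalent call would be `check_gemma_layout(GEMMA_PATH, GEMMA_VARIANT)`.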

Perform sampling/inference

  1. Load and format the Gemma model checkpoint with the gemma.params.load_and_format_params method:
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
  2. Load the Gemma tokenizer, constructed using sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  3. To automatically load the correct configuration from the Gemma model checkpoint, use gemma.transformer.TransformerConfig. The cache_size argument is the number of time steps in the Gemma Transformer cache. Afterwards, instantiate the Gemma model as transformer with gemma.transformer.Transformer (which inherits from flax.linen.Module).
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
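One way to reason about cache_size: assuming the cache has to hold both the prompt tokens and the generated tokens (an assumption about the sampler's behaviour that is worth checking against the gemma source), a prompt of n tokens leaves at most cache_size - n steps for generation. The helper below is a hypothetical sketch of that arithmetic; the prompt length is passed in as a plain integer and should count every token fed to the model, including special tokens.

```python
def max_generation_steps(prompt_tokens: int, cache_size: int = 1024) -> int:
    """Steps left for generation if prompt and output must both fit in the cache."""
    if prompt_tokens < 0 or cache_size <= 0:
        raise ValueError("prompt_tokens must be >= 0 and cache_size must be > 0")
    return max(cache_size - prompt_tokens, 0)


print(max_generation_steps(prompt_tokens=12))                     # 1012
print(max_generation_steps(prompt_tokens=2000, cache_size=1024))  # 0
```

Under this assumption, a longer prompt or a larger generation budget calls for a larger cache_size, at the cost of more accelerator memory.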
  4. Create a sampler with gemma.sampler.Sampler on top of the Gemma model checkpoint/weights and the tokenizer:
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)
  5. Write a prompt in the prompt list and perform inference. You can adjust total_generation_steps, the number of steps performed when generating a response; this example uses 128 to conserve host memory.
prompt = [
    "what is JAX in 3 bullet points?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=128,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt:
what is JAX in 3 bullet points?
Output:


* **High-performance numerical computation:** JAX leverages the power of GPUs and TPUs to accelerate complex mathematical operations, making it ideal for scientific computing, machine learning, and data analysis.
* **Automatic differentiation:** JAX provides automatic differentiation capabilities, allowing you to compute gradients and optimize models efficiently. This simplifies the process of training deep learning models.
* **Functional programming:** JAX embraces functional programming principles, promoting code readability and maintainability. It offers a flexible and expressive syntax for defining and manipulating data. 


<end_of_turn>
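The sample output above ends with the `<end_of_turn>` marker and is surrounded by blank lines. If only the answer text is needed, it can be trimmed with a few lines of string handling. This is a sketch: `clean_reply` is a hypothetical helper that works on a plain string and does not depend on the gemma library.

```python
END_OF_TURN = "<end_of_turn>"


def clean_reply(text: str) -> str:
    """Keep the text before the first end-of-turn marker and trim whitespace."""
    head, _, _ = text.partition(END_OF_TURN)
    return head.strip()


raw = "\n\n* **Functional programming:** ...\n\n\n<end_of_turn>"
print(clean_reply(raw))
```

In the notebook, the same helper could be applied to each entry of reply.text, for example `[clean_reply(t) for t in reply.text]`.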
  6. (Optional) Run this cell to free up memory if you have completed the notebook and want to try another prompt. Afterwards, you can instantiate the sampler again (the step that calls gemma.sampler.Sampler), then customize and run the prompt (the inference step).
del sampler

Learn more