Inférence avec Gemma à l'aide de JAX et Flax

Voir sur ai.google.dev Exécuter dans Google Colab Ouvrir dans Vertex AI Afficher la source sur GitHub

Présentation

Gemma est une famille de grands modèles de langage ouverts, légers et de pointe, qui est basé sur la technologie et la recherche Google DeepMind Gemini. Ce tutoriel explique comment effectuer un échantillonnage/inférence de base avec le modèle Instruct Gemma 2B à l'aide de la bibliothèque gemma de Google DeepMind, écrite avec JAX (une bibliothèque de calcul numérique hautes performances), Flax (bibliothèque de réseaux de neurones basée sur JAX), Orbax (bibliothèque JAX pour les utilitaires d'entraînement tels que le jeton de contrôle) et SentencePiece Flax n'est pas utilisé directement dans ce notebook, mais il a servi à créer Gemma.

Ce notebook peut s'exécuter sur Google Colab avec un GPU T4 sans frais (accédez à Modifier > Paramètres du notebook > Accélérateur matériel, sélectionnez GPU T4).

Préparation

1. Configurer l'accès à Kaggle pour Gemma

Pour suivre ce tutoriel, vous devez d'abord suivre les instructions de configuration de l'article Configuration de Gemma, qui expliquent comment effectuer les opérations suivantes:

  • Accédez à Gemma sur kaggle.com.
  • Sélectionnez un environnement d'exécution Colab disposant de suffisamment de ressources pour exécuter le modèle Gemma.
  • Générer et configurer un nom d'utilisateur et une clé d'API Kaggle.

Une fois la configuration de Gemma terminée, passez à la section suivante, dans laquelle vous allez définir les variables d'environnement de votre environnement Colab.

2. Définir des variables d'environnement

Définissez les variables d'environnement pour KAGGLE_USERNAME et KAGGLE_KEY. Lorsque le message "Accorder l'accès ?" s'affiche, acceptez de fournir un accès secret.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. Installer la bibliothèque gemma

Ce notebook se concentre sur l'utilisation d'un GPU Colab sans frais. Pour activer l'accélération matérielle, cliquez sur Edit > Notebook settings (Modifier > Paramètres du notebook) > T4 GPU > Save (Enregistrer).

Vous devez ensuite installer la bibliothèque Google DeepMind gemma à partir de github.com/google-deepmind/gemma. Si vous obtenez une erreur concernant le "résolveur de dépendances de pip", vous pouvez généralement l'ignorer.

pip install -q git+https://github.com/google-deepmind/gemma.git

Charger et préparer le modèle Gemma

  1. Chargez le modèle Gemma avec kagglehub.model_download, qui accepte trois arguments:
  • handle: handle du modèle de Kaggle
  • path: (chaîne facultative) chemin d'accès local.
  • force_download: (valeur booléenne facultative) force le téléchargement à nouveau du modèle.
GEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/gemma/flax/2b-it/2/download...
100%|██████████| 3.67G/3.67G [00:35<00:00, 110MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2
  1. Vérifiez l'emplacement des pondérations du modèle et de la fonction de tokenisation, puis définissez les variables de chemin d'accès. Le répertoire de tokenisation se trouve dans le répertoire principal où vous avez téléchargé le modèle, tandis que les pondérations du modèle figurent dans un sous-répertoire. Exemple :
  • Le fichier tokenizer.model se trouvera au format /LOCAL/PATH/TO/gemma/flax/2b-it/2.
  • Le point de contrôle du modèle se trouve dans /LOCAL/PATH/TO/gemma/flax/2b-it/2/2b-it).
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/gemma/flax/2b-it/2/tokenizer.model

Effectuer un échantillonnage/inférence

  1. Chargez et formatez le point de contrôle du modèle Gemma avec la méthode gemma.params.load_and_format_params:
from gemma import params as params_lib

params = params_lib.load_and_format_params(CKPT_PATH)
  1. Chargez le générateur de jetons Gemma, créé à l'aide de sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Pour charger automatiquement la configuration correcte à partir du point de contrôle du modèle Gemma, utilisez gemma.transformer.TransformerConfig. L'argument cache_size correspond au nombre de pas de temps dans le cache Gemma Transformer. Ensuite, instanciez le modèle Gemma en tant que transformer avec gemma.transformer.Transformer (qui hérite de flax.linen.Module).
from gemma import transformer as transformer_lib

transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)
  1. Créez un sampler avec gemma.sampler.Sampler au-dessus du point de contrôle/pondérations du modèle Gemma et de la fonction de tokenisation:
from gemma import sampler as sampler_lib

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)
  1. Écrivez une requête dans input_batch et effectuez une inférence. Vous pouvez ajuster total_generation_steps (le nombre d'étapes effectuées lors de la génération d'une réponse ; cet exemple utilise 100 pour préserver la mémoire hôte).
prompt = [
    "\n# What is the meaning of life?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=100,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
Prompt:

# What is the meaning of life?
Output:


The question of what the meaning of life is one that has occupied the minds of philosophers, theologians, and individuals for centuries. There is no single, universally accepted answer, but there are many different perspectives on this complex and multifaceted question.

**Some common perspectives on the meaning of life include:**

* **Biological perspective:** From a biological standpoint, the meaning of life is to survive and reproduce.
* **Existential perspective:** Existentialists believe that life is not inherently meaningful and that
  1. (Facultatif) Si vous avez terminé le notebook et que vous souhaitez essayer une autre invite, exécutez cette cellule pour libérer de la mémoire. Vous pouvez ensuite instancier à nouveau sampler à l'étape 3, puis personnaliser et exécuter l'invite de l'étape 4.
del sampler

En savoir plus