Inférence avec RecurrentGemma à l'aide de JAX et Flax

Afficher sur ai.google.dev Exécuter dans Google Colab Ouvrir dans Vertex AI Consulter le code source sur GitHub

Ce tutoriel explique comment effectuer un échantillonnage/une inférence de base avec le modèle 2B Instruct RecurrentGemma à l'aide de la bibliothèque recurrentgemma de Google DeepMind écrite avec JAX (bibliothèque de calcul numérique hautes performances), Flax (bibliothèque de réseaux de neurones basés sur JAX), Orbax (bibliothèque de jetons JAX/tokenSentence) pour l'entraînement (bibliothèques de jetons JAX/tokenSent2, comme les utilitaires Checkpoint2).SentencePiece Bien que le lin n'est pas utilisé directement dans ce carnet, il a été utilisé pour créer Gemma et RecurrentGemma (le modèle Griffin).

Ce notebook peut s'exécuter sur Google Colab avec le GPU T4. Pour ce faire, accédez à Modifier > Paramètres du notebook > sous Accélérateur matériel, sélectionnez GPU T4.

Configuration

Les sections suivantes expliquent les étapes de préparation d'un notebook à utiliser un modèle RecurrentGemma, y compris l'accès au modèle, l'obtention d'une clé API et la configuration de l'environnement d'exécution du notebook

Configurer l'accès à Kaggle pour Gemma

Pour suivre ce tutoriel, vous devez d'abord suivre des instructions de configuration semblables à celles de Gemma, à quelques exceptions près:

  • Accédez à RecurrentGemma (au lieu de Gemma) sur kaggle.com.
  • Sélectionnez un environnement d'exécution Colab disposant de suffisamment de ressources pour exécuter le modèle RecurrentGemma.
  • Générez et configurez un nom d'utilisateur et une clé API Kaggle.

Une fois la configuration de RecurrentGemma terminée, passez à la section suivante, dans laquelle vous allez définir des variables d'environnement pour votre environnement Colab.

Définir des variables d'environnement

Définissez les variables d'environnement pour KAGGLE_USERNAME et KAGGLE_KEY. Lorsque le message "Accorder l'accès ?" s'affiche, messages, acceptez de fournir un accès secret.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Installer la bibliothèque recurrentgemma

Ce notebook se concentre sur l'utilisation d'un GPU Colab sans frais. Pour activer l'accélération matérielle, cliquez sur Modifier > Paramètres du notebook > Sélectionnez GPU T4 > Enregistrer.

Vous devez ensuite installer la bibliothèque Google DeepMind recurrentgemma à partir de github.com/google-deepmind/recurrentgemma. Si vous obtenez une erreur concernant le "résolveur de dépendances de pip", vous pouvez généralement l'ignorer.

pip install git+https://github.com/google-deepmind/recurrentgemma.git

Charger et préparer le modèle RecurrentGemma

  1. Chargez le modèle RecurrentGemma avec kagglehub.model_download, qui utilise trois arguments:
  • handle: le gestionnaire de modèle de Kaggle
  • path (chaîne facultative) : chemin d'accès local
  • force_download: (valeur booléenne facultative) force le téléchargement à nouveau du modèle.
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub

RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download...
100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s]
Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
  1. Vérifiez l'emplacement des pondérations du modèle et de la fonction de tokenisation, puis définissez les variables de chemin. Le répertoire de tokenisation se trouve dans le répertoire principal dans lequel vous avez téléchargé le modèle, tandis que les pondérations du modèle sont dans un sous-répertoire. Exemple :
  • Le fichier tokenizer.model se trouvera dans /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1).
  • Le point de contrôle du modèle se trouvera dans /LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model

Effectuer un échantillonnage/une inférence

  1. Charger le point de contrôle du modèle RecurrentGemma à l'aide de la méthode recurrentgemma.jax.load_parameters L'argument sharding défini sur "single_device" charge tous les paramètres du modèle sur un seul appareil.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma

params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
  1. Chargez la fonction de tokenisation du modèle RecurrentGemma, créée à l'aide de sentencepiece.SentencePieceProcessor:
import sentencepiece as spm

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
  1. Pour charger automatiquement la configuration appropriée à partir du point de contrôle du modèle RecurrentGemma, utilisez recurrentgemma.GriffinConfig.from_flax_params_or_variables. Instanciez ensuite le modèle Griffin avec recurrentgemma.jax.Griffin.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
    flax_params_or_variables=params)

model = recurrentgemma.Griffin(model_config)
  1. Créez un sampler avec recurrentgemma.jax.Sampler au-dessus du point de contrôle/des pondérations du modèle RecurrentGemma et de la fonction de tokenisation:
sampler = recurrentgemma.Sampler(
    model=model,
    vocab=vocab,
    params=params,
)
  1. Écrivez une requête dans prompt et effectuez l'inférence. Vous pouvez modifier total_generation_steps (le nombre d'étapes effectuées lors de la génération d'une réponse. Cet exemple utilise 50 pour préserver la mémoire de l'hôte).
prompt = [
    "\n# 5+9=?",
]

reply = sampler(input_strings=prompt,
                total_generation_steps=50,
                )

for input_string, out_string in zip(prompt, reply.text):
    print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
Prompt:

# 5+9=?
Output:


# Answer: 14

# Explanation: 5 + 9 = 14.

En savoir plus