Afficher sur ai.google.dev | Exécuter dans Google Colab | Ouvrir dans Vertex AI | Consulter le code source sur GitHub |
Ce tutoriel explique comment effectuer un échantillonnage/une inférence de base avec le modèle 2B Instruct RecurrentGemma à l'aide de la bibliothèque recurrentgemma
de Google DeepMind écrite avec JAX (bibliothèque de calcul numérique hautes performances), Flax (bibliothèque de réseaux de neurones basés sur JAX), Orbax (bibliothèque de jetons JAX/tokenSentence) pour l'entraînement (bibliothèques de jetons JAX/tokenSent2, comme les utilitaires Checkpoint2).SentencePiece Bien que le lin n'est pas utilisé directement dans ce carnet, il a été utilisé pour créer Gemma et RecurrentGemma (le modèle Griffin).
Ce notebook peut s'exécuter sur Google Colab avec le GPU T4. Pour ce faire, accédez à Modifier > Paramètres du notebook > sous Accélérateur matériel, sélectionnez GPU T4.
Configuration
Les sections suivantes expliquent les étapes de préparation d'un notebook à utiliser un modèle RecurrentGemma, y compris l'accès au modèle, l'obtention d'une clé API et la configuration de l'environnement d'exécution du notebook
Configurer l'accès à Kaggle pour Gemma
Pour suivre ce tutoriel, vous devez d'abord suivre des instructions de configuration semblables à celles de Gemma, à quelques exceptions près:
- Accédez à RecurrentGemma (au lieu de Gemma) sur kaggle.com.
- Sélectionnez un environnement d'exécution Colab disposant de suffisamment de ressources pour exécuter le modèle RecurrentGemma.
- Générez et configurez un nom d'utilisateur et une clé API Kaggle.
Une fois la configuration de RecurrentGemma terminée, passez à la section suivante, dans laquelle vous allez définir des variables d'environnement pour votre environnement Colab.
Définir des variables d'environnement
Définissez les variables d'environnement pour KAGGLE_USERNAME
et KAGGLE_KEY
. Lorsque le message "Accorder l'accès ?" s'affiche, messages, acceptez de fournir un accès secret.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Installer la bibliothèque recurrentgemma
Ce notebook se concentre sur l'utilisation d'un GPU Colab sans frais. Pour activer l'accélération matérielle, cliquez sur Modifier > Paramètres du notebook > Sélectionnez GPU T4 > Enregistrer.
Vous devez ensuite installer la bibliothèque Google DeepMind recurrentgemma
à partir de github.com/google-deepmind/recurrentgemma
. Si vous obtenez une erreur concernant le "résolveur de dépendances de pip", vous pouvez généralement l'ignorer.
pip install git+https://github.com/google-deepmind/recurrentgemma.git
Charger et préparer le modèle RecurrentGemma
- Chargez le modèle RecurrentGemma avec
kagglehub.model_download
, qui utilise trois arguments:
handle
: le gestionnaire de modèle de Kagglepath
(chaîne facultative) : chemin d'accès localforce_download
: (valeur booléenne facultative) force le téléchargement à nouveau du modèle.
RECURRENTGEMMA_VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
import kagglehub
RECURRENTGEMMA_PATH = kagglehub.model_download(f'google/recurrentgemma/flax/{RECURRENTGEMMA_VARIANT}')
Downloading from https://www.kaggle.com/api/v1/models/google/recurrentgemma/flax/2b-it/1/download... 100%|██████████| 3.85G/3.85G [00:52<00:00, 78.2MB/s] Extracting model files...
print('RECURRENTGEMMA_PATH:', RECURRENTGEMMA_PATH)
RECURRENTGEMMA_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1
- Vérifiez l'emplacement des pondérations du modèle et de la fonction de tokenisation, puis définissez les variables de chemin. Le répertoire de tokenisation se trouve dans le répertoire principal dans lequel vous avez téléchargé le modèle, tandis que les pondérations du modèle sont dans un sous-répertoire. Exemple :
- Le fichier
tokenizer.model
se trouvera dans/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1
). - Le point de contrôle du modèle se trouvera dans
/LOCAL/PATH/TO/recurrentgemma/flax/2b-it/1/2b-it
).
CKPT_PATH = os.path.join(RECURRENTGEMMA_PATH, RECURRENTGEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(RECURRENTGEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/2b-it TOKENIZER_PATH: /root/.cache/kagglehub/models/google/recurrentgemma/flax/2b-it/1/tokenizer.model
Effectuer un échantillonnage/une inférence
- Charger le point de contrôle du modèle RecurrentGemma à l'aide de la méthode
recurrentgemma.jax.load_parameters
L'argumentsharding
défini sur"single_device"
charge tous les paramètres du modèle sur un seul appareil.
import recurrentgemma
from recurrentgemma import jax as recurrentgemma
params = recurrentgemma.load_parameters(checkpoint_path=CKPT_PATH, sharding="single_device")
- Chargez la fonction de tokenisation du modèle RecurrentGemma, créée à l'aide de
sentencepiece.SentencePieceProcessor
:
import sentencepiece as spm
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
- Pour charger automatiquement la configuration appropriée à partir du point de contrôle du modèle RecurrentGemma, utilisez
recurrentgemma.GriffinConfig.from_flax_params_or_variables
. Instanciez ensuite le modèle Griffin avecrecurrentgemma.jax.Griffin
.
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(
flax_params_or_variables=params)
model = recurrentgemma.Griffin(model_config)
- Créez un
sampler
avecrecurrentgemma.jax.Sampler
au-dessus du point de contrôle/des pondérations du modèle RecurrentGemma et de la fonction de tokenisation:
sampler = recurrentgemma.Sampler(
model=model,
vocab=vocab,
params=params,
)
- Écrivez une requête dans
prompt
et effectuez l'inférence. Vous pouvez modifiertotal_generation_steps
(le nombre d'étapes effectuées lors de la génération d'une réponse. Cet exemple utilise50
pour préserver la mémoire de l'hôte).
prompt = [
"\n# 5+9=?",
]
reply = sampler(input_strings=prompt,
total_generation_steps=50,
)
for input_string, out_string in zip(prompt, reply.text):
print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:920: UserWarning: Some donated buffers were not usable: ShapedArray(int32[1,8]). See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation. warnings.warn("Some donated buffers were not usable:" Prompt: # 5+9=? Output: # Answer: 14 # Explanation: 5 + 9 = 14.
En savoir plus
- Pour en savoir plus sur la bibliothèque
recurrentgemma
de Google DeepMind sur GitHub, qui contient les docstrings des méthodes et des modules que vous avez utilisés dans ce tutoriel, tels querecurrentgemma.jax.load_parameters
,recurrentgemma.jax.Griffin
etrecurrentgemma.jax.Sampler
. - Les bibliothèques suivantes possèdent leurs propres sites de documentation: core JAX, Flax et Orbax.
- Pour obtenir de la documentation sur la fonction de tokenisation et de détokenisation
sentencepiece
, consultez le dépôt GitHubsentencepiece
de Google. - Pour obtenir de la documentation sur
kagglehub
, accédez àREADME.md
dans le dépôt GitHubkagglehub
de Kaggle. - Découvrez comment utiliser des modèles Gemma avec Google Cloud Vertex AI.
- Regardez l'émission RecurrentGemma: Moving Past Transformers pour des modèles de langage ouverts efficaces de Google DeepMind.
- Lisez l'article Griffin: Mixing Gated Linear Recurrences with article de GoogleDeepMind consacré à l'attention locale pour des modèles de langage efficaces, pour en savoir plus sur l'architecture des modèles utilisée par RecurrentGemma.