Premiers pas avec Gemma en utilisant KerasNLP

Voir sur ai.google.dev Exécuter dans Google Colab Ouvrir dans Vertex AI Afficher la source sur GitHub

Ce tutoriel vous explique comment utiliser Gemma avec KerasNLP. Gemma est une famille de modèles ouverts légers et de pointe, élaborés à partir des mêmes recherches et technologies que celles utilisées pour créer les modèles Gemini. KerasNLP est un ensemble de modèles de traitement du langage naturel (TLN) implémentés dans Keras et exécutables sur JAX, PyTorch et TensorFlow.

Dans ce tutoriel, vous allez utiliser Gemma pour générer des réponses textuelles à plusieurs requêtes. Si vous débutez avec Keras, vous pouvez consulter la page Premiers pas avec Keras avant de commencer, mais ce n'est pas obligatoire. Vous en apprendrez davantage sur Keras au fur et à mesure de ce tutoriel.

Préparation

Configuration de Gemma

Pour suivre ce tutoriel, vous devez d'abord suivre les instructions de configuration de l'article Configuration de Gemma. Les instructions de configuration de Gemma vous indiquent comment effectuer les opérations suivantes:

  • Accédez à Gemma sur kaggle.com.
  • Sélectionnez un environnement d'exécution Colab disposant de suffisamment de ressources pour exécuter le modèle Gemma 2B.
  • Générer et configurer un nom d'utilisateur et une clé d'API Kaggle.

Une fois la configuration de Gemma terminée, passez à la section suivante, dans laquelle vous allez définir les variables d'environnement de votre environnement Colab.

Définir des variables d'environnement

Définissez les variables d'environnement pour KAGGLE_USERNAME et KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Installer des dépendances

installer Keras et KerasNLP ;

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U keras>=3

Sélectionnez un backend

Keras est une API de deep learning multi-framework de haut niveau, conçue pour être simple et facile à utiliser. Avec Keras 3, vous pouvez choisir le backend: TensorFlow, JAX ou PyTorch. Les trois fonctionneront pour ce tutoriel.

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Importer des packages

Importer Keras et KerasNLP

import keras
import keras_nlp

Créer un modèle

KerasNLP fournit des implémentations de nombreuses architectures de modèles courantes. Dans ce tutoriel, vous allez créer un modèle à l'aide de GemmaCausalLM, un modèle Gemma de bout en bout destiné à la modélisation du langage causal. Un modèle de langage causal prédit le jeton suivant en fonction des jetons précédents.

Créez le modèle à l'aide de la méthode from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...

from_preset instancie le modèle à partir d'une architecture et de pondérations prédéfinies. Dans le code ci-dessus, la chaîne "gemma_2b_en" spécifie l'architecture prédéfinie: un modèle Gemma avec deux milliards de paramètres.

Utilisez summary pour obtenir plus d'informations sur le modèle:

gemma_lm.summary()

Comme vous pouvez le voir dans le résumé, le modèle comporte 2,5 milliards de paramètres entraînables.

Générer du texte

Vous pouvez maintenant générer du texte. Le modèle comporte une méthode generate qui génère du texte en fonction d'une requête. L'argument facultatif max_length spécifie la longueur maximale de la séquence générée.

Essayez-le avec l'invite "What is the meaning of life?".

gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives'

Réessayez d'appeler generate avec une autre invite.

gemma_lm.generate("How does the brain work?", max_length=64)
'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up'

Si vous exécutez la commande sur des backends JAX ou TensorFlow, vous remarquerez que le deuxième appel generate est renvoyé presque instantanément. En effet, chaque appel à generate pour une taille de lot donnée et à max_length est compilé avec XLA. La première exécution est coûteuse, mais les exécutions suivantes sont beaucoup plus rapides.

Vous pouvez également fournir des requêtes par lot en utilisant une liste en tant qu'entrée:

gemma_lm.generate(
    ["What is the meaning of life?",
     "How does the brain work?"],
    max_length=64)
['What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives',
 'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up']

Facultatif: Essayez un autre échantillonneur

Vous pouvez contrôler la stratégie de génération pour GemmaCausalLM en définissant l'argument sampler sur compile(). Par défaut, l'échantillonnage de "greedy" sera utilisé.

Pour effectuer un test, essayez de définir une stratégie "top_k":

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life? That was a question I asked myself as I was driving home from work one night in 2012. I was driving through the city of San Bernardino, and all I could think was, “What the heck am I doing?”\n\nMy life was completely different. I'

Alors que l'algorithme gourmand par défaut sélectionne toujours le jeton ayant la probabilité la plus élevée, l'algorithme top-K choisit aléatoirement le jeton suivant parmi les jetons de probabilité top-K.

Vous n'avez pas besoin de spécifier un échantillonneur. Vous pouvez ignorer le dernier extrait de code s'il n'est pas utile à votre cas d'utilisation. Pour en savoir plus sur les échantillonneurs disponibles, consultez la page Échantillonneurs.

Étapes suivantes

Dans ce tutoriel, vous avez appris à générer du texte à l'aide de KerasNLP et Gemma. Voici quelques suggestions d'informations à retenir: