Exécuter des inférences avec Gemma à l'aide de Keras

Afficher sur ai.google.dev Exécuter dans Google Colab Ouvrir dans Vertex AI Consulter le code source sur GitHub

Ce tutoriel vous explique comment utiliser Gemma avec KerasNLP pour exécuter des inférences et générer du texte. Gemma est une famille de modèles ouverts légers et de pointe, élaborés à partir des recherches et des technologies utilisées pour créer les modèles Gemini. KerasNLP est un ensemble de modèles de traitement du langage naturel (TLN) implémentés dans Keras et exécutables sur JAX, PyTorch et TensorFlow.

Dans ce tutoriel, vous allez utiliser Gemma pour générer des réponses textuelles à plusieurs requêtes. Si vous débutez avec Keras, vous pouvez consulter la page Premiers pas avec Keras avant de commencer, mais ce n'est pas obligatoire. Vous en apprendrez davantage sur Keras tout au long de ce tutoriel.

Préparation

Configuration de Gemma

Pour suivre ce tutoriel, vous devez d'abord suivre les instructions de configuration de Gemma. Les instructions de configuration de Gemma vous expliquent comment:

  • Accédez à Gemma sur kaggle.com.
  • Sélectionnez un environnement d'exécution Colab disposant de suffisamment de ressources pour exécuter le modèle Gemma 2B.
  • Générez et configurez un nom d'utilisateur et une clé API Kaggle.

Une fois la configuration de Gemma terminée, passez à la section suivante, dans laquelle vous allez définir des variables d'environnement pour votre environnement Colab.

Définir des variables d'environnement

Définissez les variables d'environnement pour KAGGLE_USERNAME et KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Installer des dépendances

installer Keras et KerasNLP ;

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U keras>=3

Sélectionnez un backend

Keras est une API de deep learning multi-framework de haut niveau, conçue pour être simple et facile à utiliser. Keras 3 vous permet de choisir le backend: TensorFlow, JAX ou PyTorch. Les trois fonctionnent pour ce tutoriel.

import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Importer des packages

Importer Keras et KerasNLP

import keras
import keras_nlp

Créer un modèle

KerasNLP fournit des implémentations de nombreuses architectures de modèles courantes. Dans ce tutoriel, vous allez créer un modèle à l'aide de GemmaCausalLM, un modèle Gemma de bout en bout destiné à la modélisation du langage causale. Un modèle de langage causal prédit le jeton suivant en fonction des jetons précédents.

Créez le modèle à l'aide de la méthode from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...

La fonction GemmaCausalLM.from_preset() instancie le modèle à partir d'une architecture et de pondérations prédéfinies. Dans le code ci-dessus, la chaîne "gemma_2b_en" spécifie le préréglage du modèle Gemma 2B avec deux milliards de paramètres. Des modèles Gemma avec des paramètres 7B, 9B et 27B sont également disponibles. Vous trouverez les chaînes de code des modèles Gemma dans la liste des variantes de modèle sur kaggle.com.

Utilisez summary pour obtenir plus d'informations sur le modèle:

gemma_lm.summary()

Comme vous pouvez le voir dans le résumé, le modèle comporte 2,5 milliards de paramètres pouvant être entraînés.

Générer du texte

Il est maintenant temps de générer du texte ! Le modèle dispose d'une méthode generate qui génère du texte en fonction d'une requête. L'argument facultatif max_length spécifie la longueur maximale de la séquence générée.

Essayez avec l'invite "What is the meaning of life?".

gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives'

Réessayez d'appeler generate avec une autre requête.

gemma_lm.generate("How does the brain work?", max_length=64)
'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up'

Si vous exécutez l'application sur des backends JAX ou TensorFlow, vous remarquerez que le deuxième appel generate est renvoyé presque instantanément. En effet, chaque appel à generate pour une taille de lot donnée et max_length est compilé avec XLA. La première exécution est coûteuse, mais les exécutions suivantes sont beaucoup plus rapides.

Vous pouvez également fournir des requêtes par lot en utilisant une liste en entrée:

gemma_lm.generate(
    ["What is the meaning of life?",
     "How does the brain work?"],
    max_length=64)
['What is the meaning of life?\n\nThe question is one of the most important questions in the world.\n\nIt’s the question that has been asked by philosophers, theologians, and scientists for centuries.\n\nAnd it’s the question that has been asked by people who are looking for answers to their own lives',
 'How does the brain work?\n\nThe brain is the most complex organ in the human body. It is responsible for controlling all of the body’s functions, including breathing, heart rate, digestion, and more. The brain is also responsible for thinking, feeling, and making decisions.\n\nThe brain is made up']

Facultatif: Essayer un autre échantillonneur

Vous pouvez contrôler la stratégie de génération pour GemmaCausalLM en définissant l'argument sampler sur compile(). Par défaut, l'échantillonnage "greedy" est utilisé.

Pour effectuer un test, essayez de définir une stratégie "top_k":

gemma_lm.compile(sampler="top_k")
gemma_lm.generate("What is the meaning of life?", max_length=64)
'What is the meaning of life? That was a question I asked myself as I was driving home from work one night in 2012. I was driving through the city of San Bernardino, and all I could think was, “What the heck am I doing?”\n\nMy life was completely different. I'

Alors que l'algorithme gourmand par défaut choisit toujours le jeton avec la probabilité la plus élevée, l'algorithme des top-K choisit au hasard le jeton suivant parmi les jetons de probabilité supérieure de K.

Vous n'avez pas besoin de spécifier d'échantillonneur, et vous pouvez ignorer le dernier extrait de code s'il n'est pas utile à votre cas d'utilisation. Pour en savoir plus sur les échantillonneurs disponibles, consultez la section Échantillonneurs.

Étapes suivantes

Dans ce tutoriel, vous avez appris à générer du texte à l'aide de KerasNLP et Gemma. Voici quelques suggestions pour la suite: