Créer un chatbot avec Gemma

Afficher sur ai.google.dev Exécuter dans Google Colab Ouvrir dans Vertex AI Consulter le code source sur GitHub

Les grands modèles de langage (LLM) tels que Gemma sont particulièrement efficaces pour générer des réponses informatives, ce qui les rend parfaits pour créer des assistants virtuels et des chatbots.

De manière conventionnelle, les LLM fonctionnent de manière sans état, ce qui signifie qu'ils ne disposent pas de mémoire inhérente pour stocker les conversations précédentes. Chaque requête ou question est traitée indépendamment, sans tenir compte des interactions antérieures. Cependant, un aspect crucial de la conversation naturelle est la capacité à retenir le contexte des interactions antérieures. Pour surmonter cette limitation et permettre aux LLM de maintenir le contexte de la conversation, ils doivent recevoir explicitement des informations pertinentes, telles que l'historique de la conversation (ou les parties pertinentes) dans chaque nouvelle requête présentée au LLM.

Ce tutoriel vous explique comment développer un chatbot à l'aide de la variante de modèle réglé par instruction de Gemma.

Préparation

Configuration de Gemma

Pour suivre ce tutoriel, vous devez d'abord suivre les instructions de configuration de Gemma. Les instructions de configuration de Gemma vous expliquent comment:

  • Accédez à Gemma sur kaggle.com.
  • Sélectionnez un environnement d'exécution Colab disposant de suffisamment de ressources pour exécuter le modèle Gemma 2B.
  • Générez et configurez un nom d'utilisateur et une clé API Kaggle.

Une fois la configuration de Gemma terminée, passez à la section suivante, dans laquelle vous allez définir des variables d'environnement pour votre environnement Colab.

Définir des variables d'environnement

Définissez les variables d'environnement pour KAGGLE_USERNAME et KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Installer des dépendances

installer Keras et KerasNLP ;

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q tensorflow-cpu
pip install -q -U keras-nlp tensorflow-hub
pip install -q -U keras>=3
pip install -q -U tensorflow-text

Sélectionnez un backend

Keras est une API de deep learning multi-framework de haut niveau, conçue pour être simple et facile à utiliser. Keras 3 vous permet de choisir le backend: TensorFlow, JAX ou PyTorch. Les trois fonctionnent pour ce tutoriel.

import os

# Select JAX as the backend
os.environ["KERAS_BACKEND"] = "jax"

# Pre-allocate 100% of TPU memory to minimize memory fragmentation
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

Importer des packages

Importer Keras et KerasNLP

import keras
import keras_nlp

# for reproducibility
keras.utils.set_random_seed(42)

Instancier le modèle

KerasNLP fournit des implémentations de nombreuses architectures de modèles courantes. Dans ce tutoriel, vous allez instancier le modèle à l'aide de GemmaCausalLM, un modèle Gemma de bout en bout destiné à la modélisation du langage causale. Un modèle de langage causal prédit le jeton suivant en fonction des jetons précédents.

Instanciez le modèle à l'aide de la méthode from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'task.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'preprocessor.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...

La fonction GemmaCausalLM.from_preset() instancie le modèle à partir d'une architecture et de pondérations prédéfinies. Dans le code ci-dessus, la chaîne "gemma_1.1_instruct_2b_en" spécifie le préréglage du modèle Gemma 2B avec deux milliards de paramètres. Des modèles Gemma avec des paramètres 7B, 9B et 27B sont également disponibles. Vous trouverez les chaînes de code des modèles Gemma dans les listes des variantes de modèle sur Kaggle.

Utilisez la méthode summary pour obtenir plus d'informations sur le modèle:

gemma_lm.summary()

Comme vous pouvez le voir dans le résumé, le modèle comporte 2,5 milliards de paramètres pouvant être entraînés.

Définir les fonctions d'assistance de mise en forme

from IPython.display import Markdown
import textwrap

def display_chat(prompt, text):
  formatted_prompt = "<font size='+1' color='brown'>🙋‍♂️<blockquote>" + prompt + "</blockquote></font>"
  text = text.replace('•', '  *')
  text = textwrap.indent(text, '> ', predicate=lambda _: True)
  formatted_text = "<font size='+1' color='teal'>🤖\n\n" + text + "\n</font>"
  return Markdown(formatted_prompt+formatted_text)

def to_markdown(text):
  text = text.replace('•', '  *')
  return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))

Créer le chatbot

Le modèle d'instruction Gemma gemma_1.1_instruct_2b_en est réglé pour interpréter les jetons de tour suivants:

<start_of_turn>user\n  ... <end_of_turn>\n
<start_of_turn>model\n ... <end_of_turn>\n

Ce tutoriel utilise ces jetons pour créer le chatbot. Pour en savoir plus sur les jetons de contrôle Gemma, consultez la section Formatage et instructions système.

Créer une aide de chat pour gérer l'état de la conversation

class ChatState():
  """
  Manages the conversation history for a turn-based chatbot
  Follows the turn-based conversation guidelines for the Gemma family of models
  documented at https://ai.google.dev/gemma/docs/formatting
  """

  __START_TURN_USER__ = "<start_of_turn>user\n"
  __START_TURN_MODEL__ = "<start_of_turn>model\n"
  __END_TURN__ = "<end_of_turn>\n"

  def __init__(self, model, system=""):
    """
    Initializes the chat state.

    Args:
        model: The language model to use for generating responses.
        system: (Optional) System instructions or bot description.
    """
    self.model = model
    self.system = system
    self.history = []

  def add_to_history_as_user(self, message):
      """
      Adds a user message to the history with start/end turn markers.
      """
      self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)

  def add_to_history_as_model(self, message):
      """
      Adds a model response to the history with start/end turn markers.
      """
      self.history.append(self.__START_TURN_MODEL__ + message + self.__END_TURN__)

  def get_history(self):
      """
      Returns the entire chat history as a single string.
      """
      return "".join([*self.history])

  def get_full_prompt(self):
    """
    Builds the prompt for the language model, including history and system description.
    """
    prompt = self.get_history() + self.__START_TURN_MODEL__
    if len(self.system)>0:
      prompt = self.system + "\n" + prompt
    return prompt

  def send_message(self, message):
    """
    Handles sending a user message and getting a model response.

    Args:
        message: The user's message.

    Returns:
        The model's response.
    """
    self.add_to_history_as_user(message)
    prompt = self.get_full_prompt()
    response = self.model.generate(prompt, max_length=1024)
    result = response.replace(prompt, "")  # Extract only the new response
    self.add_to_history_as_model(result)
    return result

Discuter avec le modèle

Commencez à discuter avec le modèle.

chat = ChatState(gemma_lm)
message = "Tell me, in a few words,  how to compute all prime numbers up to 1000?"
display_chat(message, chat.send_message(message))

🙋‍♂️

Dites-moi, en quelques mots, comment calculer tous les nombres premiers jusqu'à 1 000 ?
usercontent

Le tamis d'Ératosthène est une méthode largement utilisée pour calculer tous les nombres premiers jusqu'à une limite donnée. Elle implique de marquer de manière itérative les multiples de chaque nombre premier.

Poursuivez la conversation.

message = "Now in Python! No numpy, please!"
display_chat(message, chat.send_message(message))

🙋‍♂️

Maintenant en Python ! Pas de numpy, s'il vous plaît !
Sandbox

def prime(n):
    if n <= 1:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True

message = "Thank you, it works! Can you explain the code in plain English?"
display_chat(message, chat.send_message(message))

🙋‍♂️

Merci, ça a marché ! Pouvez-vous expliquer le code en toutes lettres ?
friendly

Le code fourni définit une fonction prime qui vérifie si un nombre donné est premier ou non.

Fonctionnement:

  • La fonction utilise un seul argument, n, qui est le nombre à vérifier.

  • Elle vérifie d'abord si n est inférieur ou égal à 1. Si c'est le cas, le nombre n'est pas premier, donc la fonction renvoie False.

  • Il entre ensuite dans une boucle qui parcourt des nombres de 2 jusqu'à la racine carrée de n.

  • Pour chaque nombre i, il vérifie si n est divisible de manière égale par i (autrement dit, n % i == 0).

  • Si n est divisible par i, la fonction renvoie False, car n ne peut pas être premier s'il a un diviseur.

  • Si la boucle se termine sans trouver de diviseurs pour n, la fonction renvoie True, ce qui indique que n est un nombre premier.

Exemple d'utilisation:

>>> prime(2)
True
>>> prime(3)
True
>>> prime(4)
False
>>> prime(5)
True

Avantages de ce code:

  • Il s'agit d'un algorithme simple et efficace permettant de trouver des nombres premiers.
  • Il est largement utilisé dans diverses applications mathématiques et informatiques.
message = "Great! Now add those explanations as comments in the code."
display_chat(message, chat.send_message(message))

🙋‍♂️

Parfait ! Ajoutez maintenant ces explications sous forme de commentaires dans le code.
usercontent

def prime(n):
    """
    Checks whether a given number is prime or not.

    Args:
        n: The number to check.

    Returns:
        True if n is prime, False otherwise.
    """

    # Check if n is less than or equal to 1.
    if n <= 1:
        return False

    # Iterate through numbers from 2 to the square root of n.
    for i in range(2, int(n**0.5) + 1):
        # Check if n is divisible by i.
        if n % i == 0:
            return False

    # If the loop completes without finding any divisors for n, then n is prime.
    return True

Testez la réponse générée en exécutant le code généré:

def is_prime(n):
  """
  Checks if a number is prime.

  Args:
    n: The number to check.

  Returns:
    True if n is prime, False otherwise.
  """

  # If n is less than or equal to 1, it is not prime.
  if n <= 1:
    return False

  # Iterate through all the numbers from 2 to the square root of n.
  for i in range(2, int(n**0.5) + 1):
    # If n is divisible by any of the numbers in the range from 2 to the square root of n, it is not prime.
    if n % i == 0:
      return False

  # If no divisors are found, n is prime.
  return True


# Initialize an empty list to store prime numbers.
primes = []

# Iterate through all the numbers from 2 to 1000.
for i in range(2, 1001):
  # If the number is prime, add it to the list.
  if is_prime(i):
    primes.append(i)

# Print the prime numbers.
print(primes)
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997]

Utilisez la méthode get_history pour voir comment l'ensemble du contexte a été conservé par la classe Chat.

print(chat.get_history())
<start_of_turn>user
Tell me, in a few words,  how to compute all prime numbers up to 1000?<end_of_turn>
<start_of_turn>model
The Sieve of Eratosthenes is a widely used method to compute all prime numbers up to a given limit. It involves iteratively marking out multiples of each prime number.<end_of_turn>
<start_of_turn>user
Now in Python! No numpy, please!<end_of_turn>
<start_of_turn>model

```python
def prime(n):
    if n <= 1:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True
```<end_of_turn>
<start_of_turn>user
Thank you, it works! Can you explain the code in plain English?<end_of_turn>
<start_of_turn>model
The provided code defines a function `prime` that checks whether a given number is prime or not.

**How it works:**

- The function takes a single argument, `n`, which is the number to check.


- It first checks if `n` is less than or equal to 1. If it is, the number is not prime, so the function returns `False`.


- It then enters a loop that iterates through numbers from 2 to the square root of `n`.


- For each number `i`, it checks if `n` is divisible evenly by `i` (i.e., `n % i == 0`).


- If `n` is divisible by `i`, the function returns `False` because `n` cannot be prime if it has a divisor.


- If the loop completes without finding any divisors for `n`, the function returns `True`, indicating that `n` is a prime number.


**Example Usage:**

```python
>>> prime(2)
True
>>> prime(3)
True
>>> prime(4)
False
>>> prime(5)
True
```

**Benefits of this Code:**

- It is a simple and efficient algorithm for finding prime numbers.
- It is widely used in various computer science and mathematical applications.<end_of_turn>
<start_of_turn>user
Great! Now add those explanations as comments in the code.<end_of_turn>
<start_of_turn>model
```python
def prime(n):
    """
    Checks whether a given number is prime or not.

    Args:
        n: The number to check.

    Returns:
        True if n is prime, False otherwise.
    """

    # Check if n is less than or equal to 1.
    if n <= 1:
        return False

    # Iterate through numbers from 2 to the square root of n.
    for i in range(2, int(n**0.5) + 1):
        # Check if n is divisible by i.
        if n % i == 0:
            return False

    # If the loop completes without finding any divisors for n, then n is prime.
    return True
```<end_of_turn>

Résumé et autres documents

Dans ce tutoriel, vous avez appris à discuter avec le modèle réglé pour instruction Gemma 2B à l'aide de Keras sur JAX.

Consultez ces guides et tutoriels pour en savoir plus sur Gemma: