Crea un chatbot con Gemma

Visualizza su ai.google.dev Esegui in Google Colab Apri in Vertex AI Visualizza il codice sorgente su GitHub

I modelli linguistici di grandi dimensioni (LLM), come Gemma, eccellono nel generare risposte informative, rendendoli ideali per la creazione di assistenti virtuali e chatbot.

Convenzionalmente, gli LLM operano in modo stateless, ovvero non dispongono di una memoria intrinseca per archiviare le conversazioni passate. Ogni prompt o domanda viene elaborato in modo indipendente, senza tenere conto delle interazioni precedenti. Tuttavia, un aspetto cruciale della conversazione naturale è la capacità di conservare il contesto delle interazioni precedenti. Per superare questa limitazione e consentire agli LLM di mantenere il contesto della conversazione, è necessario fornire esplicitamente informazioni pertinenti come la cronologia della conversazione (o le parti pertinenti) in ogni nuovo prompt presentato all'LLM.

Questo tutorial mostra come sviluppare un chatbot utilizzando la variante del modello ottimizzato per le istruzioni di Gemma.

Imposta

Configurazione di Gemma

Per completare questo tutorial, devi prima completare le istruzioni di configurazione nella pagina di configurazione di Gemma. Le istruzioni di configurazione di Gemma mostrano come fare:

  • Accedi a Gemma su kaggle.com.
  • Seleziona un runtime Colab con risorse sufficienti per eseguire il modello Gemma 2B.
  • Genera e configura un nome utente e una chiave API Kaggle.

Dopo aver completato la configurazione di Gemma, passa alla sezione successiva, in cui imposterai le variabili di ambiente per il tuo ambiente Colab.

Imposta le variabili di ambiente

Imposta le variabili di ambiente per KAGGLE_USERNAME e KAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Installa le dipendenze

Installare Keras e KerasNLP.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q tensorflow-cpu
pip install -q -U keras-nlp tensorflow-hub
pip install -q -U keras>=3
pip install -q -U tensorflow-text

Seleziona un servizio di backend

Keras è un'API di deep learning multi-framework di alto livello progettata per la semplicità e la facilità d'uso. Keras 3 ti consente di scegliere il backend: TensorFlow, JAX o PyTorch. Per questo tutorial sono validi tutti e tre.

import os

# Select JAX as the backend
os.environ["KERAS_BACKEND"] = "jax"

# Pre-allocate 100% of TPU memory to minimize memory fragmentation
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

Importa pacchetti

Importare Keras e KerasNLP.

import keras
import keras_nlp

# for reproducibility
keras.utils.set_random_seed(42)

Crea un'istanza del modello

KerasNLP fornisce implementazioni di molte architetture di modelli popolari. In questo tutorial, creerai un'istanza del modello utilizzando GemmaCausalLM, un modello Gemma end-to-end per la creazione di modelli linguistici causali. Un modello linguistico causale prevede il token successivo in base a quelli precedenti.

Crea un'istanza del modello utilizzando il metodo from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'task.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'preprocessor.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...

La funzione GemmaCausalLM.from_preset() crea un'istanza del modello da un'architettura e pesi predefiniti. Nel codice precedente, la stringa "gemma_1.1_instruct_2b_en" specifica il preset del modello Gemma 2B con 2 miliardi di parametri. Sono disponibili anche modelli Gemma con parametri 7B, 9B e 27B. Puoi trovare le stringhe di codice per i modelli Gemma negli elenchi Varianti del modello su Kaggle.

Utilizza il metodo summary per ottenere maggiori informazioni sul modello:

gemma_lm.summary()

Come puoi vedere dal riepilogo, il modello ha 2,5 miliardi di parametri addestrabili.

Definisci le funzioni helper per la formattazione

from IPython.display import Markdown
import textwrap

def display_chat(prompt, text):
  formatted_prompt = "<font size='+1' color='brown'>🙋‍♂️<blockquote>" + prompt + "</blockquote></font>"
  text = text.replace('•', '  *')
  text = textwrap.indent(text, '> ', predicate=lambda _: True)
  formatted_text = "<font size='+1' color='teal'>🤖\n\n" + text + "\n</font>"
  return Markdown(formatted_prompt+formatted_text)

def to_markdown(text):
  text = text.replace('•', '  *')
  return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))

Creazione del chatbot in corso...

Il modello ottimizzato per le istruzioni di Gemma gemma_1.1_instruct_2b_en è ottimizzato per comprendere i seguenti token dei turni:

<start_of_turn>user\n  ... <end_of_turn>\n
<start_of_turn>model\n ... <end_of_turn>\n

Questo tutorial utilizza questi token per creare il chatbot. Per ulteriori informazioni sui token di controllo Gemma, consulta Istruzioni di formattazione e di sistema.

Crea un assistente per la chat per gestire lo stato della conversazione

class ChatState():
  """
  Manages the conversation history for a turn-based chatbot
  Follows the turn-based conversation guidelines for the Gemma family of models
  documented at https://ai.google.dev/gemma/docs/formatting
  """

  __START_TURN_USER__ = "<start_of_turn>user\n"
  __START_TURN_MODEL__ = "<start_of_turn>model\n"
  __END_TURN__ = "<end_of_turn>\n"

  def __init__(self, model, system=""):
    """
    Initializes the chat state.

    Args:
        model: The language model to use for generating responses.
        system: (Optional) System instructions or bot description.
    """
    self.model = model
    self.system = system
    self.history = []

  def add_to_history_as_user(self, message):
      """
      Adds a user message to the history with start/end turn markers.
      """
      self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)

  def add_to_history_as_model(self, message):
      """
      Adds a model response to the history with start/end turn markers.
      """
      self.history.append(self.__START_TURN_MODEL__ + message + self.__END_TURN__)

  def get_history(self):
      """
      Returns the entire chat history as a single string.
      """
      return "".join([*self.history])

  def get_full_prompt(self):
    """
    Builds the prompt for the language model, including history and system description.
    """
    prompt = self.get_history() + self.__START_TURN_MODEL__
    if len(self.system)>0:
      prompt = self.system + "\n" + prompt
    return prompt

  def send_message(self, message):
    """
    Handles sending a user message and getting a model response.

    Args:
        message: The user's message.

    Returns:
        The model's response.
    """
    self.add_to_history_as_user(message)
    prompt = self.get_full_prompt()
    response = self.model.generate(prompt, max_length=1024)
    result = response.replace(prompt, "")  # Extract only the new response
    self.add_to_history_as_model(result)
    return result

Chatta con il modello

Inizia a chattare con il modello.

chat = ChatState(gemma_lm)
message = "Tell me, in a few words,  how to compute all prime numbers up to 1000?"
display_chat(message, chat.send_message(message))

🙋‍♂️

Dimmi, in poche parole, come calcolare tutti i numeri primi fino a 1000?
🤖

Il setaccio di Eratostene è un metodo ampiamente usato per calcolare tutti i numeri primi fino a un certo limite. Implica la marcatura iterativa di multipli di ogni numero primo.

Continua la conversazione.

message = "Now in Python! No numpy, please!"
display_chat(message, chat.send_message(message))

🙋‍♂️

Ora in Python. Non essere idiota, per favore!
🤖

def prime(n):
    if n <= 1:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True

message = "Thank you, it works! Can you explain the code in plain English?"
display_chat(message, chat.send_message(message))

🙋‍♂️

Grazie, funziona. Puoi spiegare il codice in un linguaggio semplice?
🤖

Il codice fornito definisce una funzione prime che verifica se un determinato numero sia primo o meno.

Come funziona:

  • La funzione prende un singolo argomento, n, che è il numero da verificare.

  • Innanzitutto, verifica se n è inferiore o uguale a 1. Se lo è, il numero non è primo, quindi la funzione restituisce False.

  • Quindi, entra in un ciclo che esegue l'iterazione attraverso i numeri da 2 alla radice quadrata di n.

  • Per ogni numero i, viene controllato se n è divisibile in modo uniforme per i (ad es. n % i == 0).

  • Se n è divisibile per i, la funzione restituisce False perché n non può essere primo se ha un divisore.

  • Se il ciclo viene completato senza trovare divisori per n, la funzione restituisce True, indicando che n è un numero primo.

Esempio di utilizzo:

>>> prime(2)
True
>>> prime(3)
True
>>> prime(4)
False
>>> prime(5)
True

Vantaggi del codice:

  • Si tratta di un algoritmo semplice ed efficiente per trovare i numeri primi.
  • È ampiamente utilizzato in varie applicazioni informatiche e matematiche.
message = "Great! Now add those explanations as comments in the code."
display_chat(message, chat.send_message(message))

🙋‍♂️

Bene. Ora aggiungi le spiegazioni come commenti nel codice.
🤖

def prime(n):
    """
    Checks whether a given number is prime or not.

    Args:
        n: The number to check.

    Returns:
        True if n is prime, False otherwise.
    """

    # Check if n is less than or equal to 1.
    if n <= 1:
        return False

    # Iterate through numbers from 2 to the square root of n.
    for i in range(2, int(n**0.5) + 1):
        # Check if n is divisible by i.
        if n % i == 0:
            return False

    # If the loop completes without finding any divisors for n, then n is prime.
    return True

Testa la risposta generata eseguendo il codice generato:

def is_prime(n):
  """
  Checks if a number is prime.

  Args:
    n: The number to check.

  Returns:
    True if n is prime, False otherwise.
  """

  # If n is less than or equal to 1, it is not prime.
  if n <= 1:
    return False

  # Iterate through all the numbers from 2 to the square root of n.
  for i in range(2, int(n**0.5) + 1):
    # If n is divisible by any of the numbers in the range from 2 to the square root of n, it is not prime.
    if n % i == 0:
      return False

  # If no divisors are found, n is prime.
  return True


# Initialize an empty list to store prime numbers.
primes = []

# Iterate through all the numbers from 2 to 1000.
for i in range(2, 1001):
  # If the number is prime, add it to the list.
  if is_prime(i):
    primes.append(i)

# Print the prime numbers.
print(primes)
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997]

Utilizza il metodo get_history per vedere in che modo la classe Chat ha conservato tutto il contesto.

print(chat.get_history())
<start_of_turn>user
Tell me, in a few words,  how to compute all prime numbers up to 1000?<end_of_turn>
<start_of_turn>model
The Sieve of Eratosthenes is a widely used method to compute all prime numbers up to a given limit. It involves iteratively marking out multiples of each prime number.<end_of_turn>
<start_of_turn>user
Now in Python! No numpy, please!<end_of_turn>
<start_of_turn>model

```python
def prime(n):
    if n <= 1:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True
```<end_of_turn>
<start_of_turn>user
Thank you, it works! Can you explain the code in plain English?<end_of_turn>
<start_of_turn>model
The provided code defines a function `prime` that checks whether a given number is prime or not.

**How it works:**

- The function takes a single argument, `n`, which is the number to check.


- It first checks if `n` is less than or equal to 1. If it is, the number is not prime, so the function returns `False`.


- It then enters a loop that iterates through numbers from 2 to the square root of `n`.


- For each number `i`, it checks if `n` is divisible evenly by `i` (i.e., `n % i == 0`).


- If `n` is divisible by `i`, the function returns `False` because `n` cannot be prime if it has a divisor.


- If the loop completes without finding any divisors for `n`, the function returns `True`, indicating that `n` is a prime number.


**Example Usage:**

```python
>>> prime(2)
True
>>> prime(3)
True
>>> prime(4)
False
>>> prime(5)
True
```

**Benefits of this Code:**

- It is a simple and efficient algorithm for finding prime numbers.
- It is widely used in various computer science and mathematical applications.<end_of_turn>
<start_of_turn>user
Great! Now add those explanations as comments in the code.<end_of_turn>
<start_of_turn>model
```python
def prime(n):
    """
    Checks whether a given number is prime or not.

    Args:
        n: The number to check.

    Returns:
        True if n is prime, False otherwise.
    """

    # Check if n is less than or equal to 1.
    if n <= 1:
        return False

    # Iterate through numbers from 2 to the square root of n.
    for i in range(2, int(n**0.5) + 1):
        # Check if n is divisible by i.
        if n % i == 0:
            return False

    # If the loop completes without finding any divisors for n, then n is prime.
    return True
```<end_of_turn>

Riepilogo e ulteriori approfondimenti

In questo tutorial hai imparato a chattare con il modello ottimizzato di Gemma 2B Instruction utilizzando Keras su JAX.

Consulta queste guide e questi tutorial per scoprire di più su Gemma: