Mit Gemma einen Chatbot erstellen

Auf ai.google.dev ansehen In Google Colab ausführen In Vertex AI öffnen Quelle auf GitHub ansehen

Large Language Models (LLMs) wie Gemma liefern besonders informative Antworten und sind daher ideal für die Entwicklung von virtuellen Assistenten und Chatbots.

LLMs sind üblicherweise zustandslos, was bedeutet, dass ihnen ein inhärenter Speicher zum Speichern vergangener Unterhaltungen fehlt. Jede Aufforderung oder Frage wird unabhängig und ohne Berücksichtigung vorheriger Interaktionen verarbeitet. Ein entscheidender Aspekt eines natürlichen Gesprächs ist jedoch die Fähigkeit, den Kontext aus früheren Interaktionen zu bewahren. Um diese Einschränkung zu überwinden und LLMs zu ermöglichen, den Unterhaltungskontext aufrechtzuerhalten, müssen sie in jedem neuen Prompt, der dem LLM präsentiert wird, explizit relevante Informationen wie den Unterhaltungsverlauf (oder entsprechende Teile) erhalten.

In dieser Anleitung erfahren Sie, wie Sie einen Chatbot mit der auf Anweisungen abgestimmten Modellvariante von Gemma entwickeln.

Einrichtung

Gemma-Einrichtung

Um diese Anleitung abzuschließen, müssen Sie zuerst die Schritte unter Gemma-Einrichtung ausführen. In der Anleitung zur Einrichtung von Gemma erfahren Sie, wie Sie Folgendes tun können:

  • Auf kaggle.com erhältst du Zugriff auf Gemma.
  • Wählen Sie eine Colab-Laufzeit mit ausreichenden Ressourcen zum Ausführen des Gemma 2B-Modells aus.
  • Generieren und konfigurieren Sie einen Kaggle-Nutzernamen und einen API-Schlüssel.

Nachdem Sie die Gemma-Einrichtung abgeschlossen haben, fahren Sie mit dem nächsten Abschnitt fort. Dort legen Sie Umgebungsvariablen für Ihre Colab-Umgebung fest.

Umgebungsvariablen festlegen

Legen Sie Umgebungsvariablen für KAGGLE_USERNAME und KAGGLE_KEY fest.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Abhängigkeiten installieren

Installieren Sie Keras und KerasNLP.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q tensorflow-cpu
pip install -q -U keras-nlp tensorflow-hub
pip install -q -U keras>=3
pip install -q -U tensorflow-text

Backend auswählen

Keras ist eine Deep-Learning-API auf hoher Ebene mit mehreren Frameworks, die auf einfache und nutzerfreundliche Weise entwickelt wurde. Mit Keras 3 können Sie das Back-End auswählen: TensorFlow, JAX oder PyTorch. Alle drei funktionieren in dieser Anleitung.

import os

# Select JAX as the backend
os.environ["KERAS_BACKEND"] = "jax"

# Pre-allocate 100% of TPU memory to minimize memory fragmentation
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

Pakete importieren

Importieren Sie Keras und KerasNLP.

import keras
import keras_nlp

# for reproducibility
keras.utils.set_random_seed(42)

Modell instanziieren

KerasNLP bietet Implementierungen vieler beliebter Modellarchitekturen. In dieser Anleitung instanziieren Sie das Modell mit GemmaCausalLM, einem End-to-End-Gemma-Modell für kausale Language Models. Ein kausales Sprachmodell sagt das nächste Token basierend auf vorherigen Tokens voraus.

Instanziieren Sie das Modell mit der Methode from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'task.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'preprocessor.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...

Die Funktion GemmaCausalLM.from_preset() instanziiert das Modell aus einer voreingestellten Architektur und Gewichtungen. Im Code oben gibt der String "gemma_1.1_instruct_2b_en" die Voreinstellung für das Gemma 2B-Modell mit 2 Milliarden Parametern an. Gemma-Modelle mit den Parametern 7B, 9B und 27B sind ebenfalls verfügbar. Sie finden die Codestrings für Gemma-Modelle in der Liste der Modellvarianten auf Kaggle.

Verwenden Sie die Methode summary, um weitere Informationen zum Modell abzurufen:

gemma_lm.summary()

Wie Sie der Zusammenfassung entnehmen können, hat das Modell 2,5 Milliarden trainierbare Parameter.

Formatierungshilfefunktionen definieren

from IPython.display import Markdown
import textwrap

def display_chat(prompt, text):
  formatted_prompt = "<font size='+1' color='brown'>🙋‍♂️<blockquote>" + prompt + "</blockquote></font>"
  text = text.replace('•', '  *')
  text = textwrap.indent(text, '> ', predicate=lambda _: True)
  formatted_text = "<font size='+1' color='teal'>🤖\n\n" + text + "\n</font>"
  return Markdown(formatted_prompt+formatted_text)

def to_markdown(text):
  text = text.replace('•', '  *')
  return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))

Chatbot erstellen

Das auf die Gemma-Anweisung abgestimmte Modell „gemma_1.1_instruct_2b_en“ wurde optimiert, um die folgenden Rundentokens zu verstehen:

<start_of_turn>user\n  ... <end_of_turn>\n
<start_of_turn>model\n ... <end_of_turn>\n

In dieser Anleitung werden diese Tokens verwendet, um den Chatbot zu erstellen. Weitere Informationen zu Gemma-Kontrolltokens finden Sie unter Formatierungs- und Systemanleitung.

Chatassistent zum Verwalten des Unterhaltungsstatus erstellen

class ChatState():
  """
  Manages the conversation history for a turn-based chatbot
  Follows the turn-based conversation guidelines for the Gemma family of models
  documented at https://ai.google.dev/gemma/docs/formatting
  """

  __START_TURN_USER__ = "<start_of_turn>user\n"
  __START_TURN_MODEL__ = "<start_of_turn>model\n"
  __END_TURN__ = "<end_of_turn>\n"

  def __init__(self, model, system=""):
    """
    Initializes the chat state.

    Args:
        model: The language model to use for generating responses.
        system: (Optional) System instructions or bot description.
    """
    self.model = model
    self.system = system
    self.history = []

  def add_to_history_as_user(self, message):
      """
      Adds a user message to the history with start/end turn markers.
      """
      self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)

  def add_to_history_as_model(self, message):
      """
      Adds a model response to the history with start/end turn markers.
      """
      self.history.append(self.__START_TURN_MODEL__ + message + self.__END_TURN__)

  def get_history(self):
      """
      Returns the entire chat history as a single string.
      """
      return "".join([*self.history])

  def get_full_prompt(self):
    """
    Builds the prompt for the language model, including history and system description.
    """
    prompt = self.get_history() + self.__START_TURN_MODEL__
    if len(self.system)>0:
      prompt = self.system + "\n" + prompt
    return prompt

  def send_message(self, message):
    """
    Handles sending a user message and getting a model response.

    Args:
        message: The user's message.

    Returns:
        The model's response.
    """
    self.add_to_history_as_user(message)
    prompt = self.get_full_prompt()
    response = self.model.generate(prompt, max_length=1024)
    result = response.replace(prompt, "")  # Extract only the new response
    self.add_to_history_as_model(result)
    return result

Mit dem Modell chatten

Chatten Sie mit dem Modell.

chat = ChatState(gemma_lm)
message = "Tell me, in a few words,  how to compute all prime numbers up to 1000?"
display_chat(message, chat.send_message(message))

🙋‍♂️

Beschreiben Sie mir in wenigen Worten, wie man alle Primzahlen bis 1.000 berechnet.
🤖

Das Sieb der Eratosthenes ist eine weitverbreitete Methode zur Berechnung aller Primzahlen bis zu einem bestimmten Grenzwert. Dabei werden Vielfache jeder Primzahl iterativ markiert.

Setzen Sie die Unterhaltung fort.

message = "Now in Python! No numpy, please!"
display_chat(message, chat.send_message(message))

🙋‍♂️

Jetzt in Python! Kein Schnickschnack, bitte!
🤖

def prime(n):
    if n <= 1:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True

message = "Thank you, it works! Can you explain the code in plain English?"
display_chat(message, chat.send_message(message))

🙋‍♂️

Vielen Dank, es funktioniert! Können Sie mir den Code in einfachem Englisch erklären?
🤖

Der bereitgestellte Code definiert eine Funktion prime, die prüft, ob eine bestimmte Zahl eine Primzahl ist oder nicht.

So gehts:

  • Die Funktion verwendet das einzelne Argument n, das die zu prüfende Zahl ist.

  • Zuerst wird geprüft, ob n kleiner oder gleich 1 ist. Falls ja, ist die Zahl keine Primzahl, sodass die Funktion False zurückgibt.

  • Dann durchläuft es eine Schleife, die durch Zahlen von 2 bis zur Quadratwurzel aus n iteriert.

  • Für jede Zahl i wird geprüft, ob n gleichmäßig durch i teilbar ist (d.h. n % i == 0).

  • Wenn n durch i teilbar ist, gibt die Funktion False zurück, da n nicht Primzahl sein kann, wenn sie einen Divisor hat.

  • Wenn die Schleife abgeschlossen wird, ohne für n Divisoren zu finden, gibt die Funktion True zurück und gibt damit an, dass n eine Primzahl ist.

Verwendungsbeispiel:

>>> prime(2)
True
>>> prime(3)
True
>>> prime(4)
False
>>> prime(5)
True

Vorteile dieses Codes:

  • Es ist ein einfacher und effizienter Algorithmus zur Ermittlung von Primzahlen.
  • Sie ist weit verbreitet in verschiedenen Informatik- und mathematischen Anwendungen.
message = "Great! Now add those explanations as comments in the code."
display_chat(message, chat.send_message(message))

🙋‍♂️

Sehr gut! Fügen Sie diese Erklärungen jetzt als Kommentare in den Code ein.
🤖

def prime(n):
    """
    Checks whether a given number is prime or not.

    Args:
        n: The number to check.

    Returns:
        True if n is prime, False otherwise.
    """

    # Check if n is less than or equal to 1.
    if n <= 1:
        return False

    # Iterate through numbers from 2 to the square root of n.
    for i in range(2, int(n**0.5) + 1):
        # Check if n is divisible by i.
        if n % i == 0:
            return False

    # If the loop completes without finding any divisors for n, then n is prime.
    return True

Testen Sie die generierte Antwort, indem Sie den generierten Code ausführen:

def is_prime(n):
  """
  Checks if a number is prime.

  Args:
    n: The number to check.

  Returns:
    True if n is prime, False otherwise.
  """

  # If n is less than or equal to 1, it is not prime.
  if n <= 1:
    return False

  # Iterate through all the numbers from 2 to the square root of n.
  for i in range(2, int(n**0.5) + 1):
    # If n is divisible by any of the numbers in the range from 2 to the square root of n, it is not prime.
    if n % i == 0:
      return False

  # If no divisors are found, n is prime.
  return True


# Initialize an empty list to store prime numbers.
primes = []

# Iterate through all the numbers from 2 to 1000.
for i in range(2, 1001):
  # If the number is prime, add it to the list.
  if is_prime(i):
    primes.append(i)

# Print the prime numbers.
print(primes)
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997]

Verwenden Sie die Methode get_history, um zu sehen, wie der gesamte Kontext von der Klasse Chat beibehalten wurde.

print(chat.get_history())
<start_of_turn>user
Tell me, in a few words,  how to compute all prime numbers up to 1000?<end_of_turn>
<start_of_turn>model
The Sieve of Eratosthenes is a widely used method to compute all prime numbers up to a given limit. It involves iteratively marking out multiples of each prime number.<end_of_turn>
<start_of_turn>user
Now in Python! No numpy, please!<end_of_turn>
<start_of_turn>model

```python
def prime(n):
    if n <= 1:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True
```<end_of_turn>
<start_of_turn>user
Thank you, it works! Can you explain the code in plain English?<end_of_turn>
<start_of_turn>model
The provided code defines a function `prime` that checks whether a given number is prime or not.

**How it works:**

- The function takes a single argument, `n`, which is the number to check.


- It first checks if `n` is less than or equal to 1. If it is, the number is not prime, so the function returns `False`.


- It then enters a loop that iterates through numbers from 2 to the square root of `n`.


- For each number `i`, it checks if `n` is divisible evenly by `i` (i.e., `n % i == 0`).


- If `n` is divisible by `i`, the function returns `False` because `n` cannot be prime if it has a divisor.


- If the loop completes without finding any divisors for `n`, the function returns `True`, indicating that `n` is a prime number.


**Example Usage:**

```python
>>> prime(2)
True
>>> prime(3)
True
>>> prime(4)
False
>>> prime(5)
True
```

**Benefits of this Code:**

- It is a simple and efficient algorithm for finding prime numbers.
- It is widely used in various computer science and mathematical applications.<end_of_turn>
<start_of_turn>user
Great! Now add those explanations as comments in the code.<end_of_turn>
<start_of_turn>model
```python
def prime(n):
    """
    Checks whether a given number is prime or not.

    Args:
        n: The number to check.

    Returns:
        True if n is prime, False otherwise.
    """

    # Check if n is less than or equal to 1.
    if n <= 1:
        return False

    # Iterate through numbers from 2 to the square root of n.
    for i in range(2, int(n**0.5) + 1):
        # Check if n is divisible by i.
        if n % i == 0:
            return False

    # If the loop completes without finding any divisors for n, then n is prime.
    return True
```<end_of_turn>

Zusammenfassung und weitere Informationen

In dieser Anleitung haben Sie erfahren, wie Sie mit Keras auf JAX mit dem abgestimmten Modell Gemma 2B Instruction chatten.

In diesen Leitfäden und Tutorials erfahren Sie mehr über Gemma: