与 Gemma 一起构建聊天机器人

在 ai.google.dev 上查看 在 Google Colab 中运行 在 Vertex AI 中打开 在 GitHub 上查看源代码

Gemma 等大语言模型 (LLM) 擅长生成信息丰富的回答,因此非常适合构建虚拟助理和聊天机器人。

通常来说,LLM 以无状态方式运行,这意味着它们缺乏固有的记忆力来存储过去的对话。每个提示或问题都是独立处理的,而忽略之前的互动。但是,自然对话的一个重要方面是保留之前互动时的上下文。为了克服这一限制并让 LLM 保持对话上下文,必须向 LLM 明确提供相关信息,例如向 LLM 提交的每个新提示中的对话历史记录(或相关部分)。

本教程介绍了如何使用 Gemma 的指令调优模型变体开发聊天机器人。

设置

Gemma 设置

要完成本教程,首先需要按照 Gemma 设置中的说明完成设置。Gemma 设置说明介绍了如何执行以下操作:

  • 在 kaggle.com 上访问 Gemma。
  • 选择具有足够资源的 Colab 运行时来运行 Gemma 2B 模型。
  • 生成并配置 Kaggle 用户名和 API 密钥。

完成 Gemma 设置后,请继续执行下一部分,您将为 Colab 环境设置环境变量。

设置环境变量

KAGGLE_USERNAMEKAGGLE_KEY 设置环境变量。

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

安装依赖项

安装 Keras 和 KerasNLP。

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q tensorflow-cpu
pip install -q -U keras-nlp tensorflow-hub
pip install -q -U keras>=3
pip install -q -U tensorflow-text

选择一个后端

Keras 是一个高级的多框架深度学习 API,旨在实现简洁易用。Keras 3 支持选择后端:TensorFlow、JAX 或 PyTorch。这三个选项都适用于本教程。

import os

# Select JAX as the backend
os.environ["KERAS_BACKEND"] = "jax"

# Pre-allocate 100% of TPU memory to minimize memory fragmentation
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

导入软件包

导入 Keras 和 KerasNLP。

import keras
import keras_nlp

# for reproducibility
keras.utils.set_random_seed(42)

实例化模型

KerasNLP 提供了许多热门模型架构的实现。在本教程中,您将使用 GemmaCausalLM 实例化此模型,它是用于因果语言建模的端到端 Gemma 模型。因果语言模型会根据上一个词元预测下一个词元。

使用 from_preset 方法实例化模型:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'task.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'preprocessor.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...

GemmaCausalLM.from_preset() 函数会根据预设架构和权重对模型进行实例化。在上面的代码中,字符串 "gemma_1.1_instruct_2b_en" 指定了预设的 Gemma 2B 模型,其中包含 20 亿个参数。您也可以使用具有 7B、9B 和 27B 参数的 Gemma 模型。您可以在 Kaggle 上的模型变体详情中找到 Gemma 模型的代码字符串。

使用 summary 方法获取有关模型的更多信息:

gemma_lm.summary()

从摘要中可以看出,该模型有 25 亿个可训练参数。

定义格式设置辅助函数

from IPython.display import Markdown
import textwrap

def display_chat(prompt, text):
  formatted_prompt = "<font size='+1' color='brown'>🙋‍♂️<blockquote>" + prompt + "</blockquote></font>"
  text = text.replace('•', '  *')
  text = textwrap.indent(text, '> ', predicate=lambda _: True)
  formatted_text = "<font size='+1' color='teal'>🤖\n\n" + text + "\n</font>"
  return Markdown(formatted_prompt+formatted_text)

def to_markdown(text):
  text = text.replace('•', '  *')
  return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))

构建聊天机器人

Gemma 指令微调模型 gemma_1.1_instruct_2b_en 经过微调,可理解以下回合令牌:

<start_of_turn>user\n  ... <end_of_turn>\n
<start_of_turn>model\n ... <end_of_turn>\n

本教程使用这些令牌构建聊天机器人。请参阅格式设置和系统说明,详细了解 Gemma 控制令牌。

创建聊天助手来管理对话状态

class ChatState():
  """
  Manages the conversation history for a turn-based chatbot
  Follows the turn-based conversation guidelines for the Gemma family of models
  documented at https://ai.google.dev/gemma/docs/formatting
  """

  __START_TURN_USER__ = "<start_of_turn>user\n"
  __START_TURN_MODEL__ = "<start_of_turn>model\n"
  __END_TURN__ = "<end_of_turn>\n"

  def __init__(self, model, system=""):
    """
    Initializes the chat state.

    Args:
        model: The language model to use for generating responses.
        system: (Optional) System instructions or bot description.
    """
    self.model = model
    self.system = system
    self.history = []

  def add_to_history_as_user(self, message):
      """
      Adds a user message to the history with start/end turn markers.
      """
      self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)

  def add_to_history_as_model(self, message):
      """
      Adds a model response to the history with start/end turn markers.
      """
      self.history.append(self.__START_TURN_MODEL__ + message + self.__END_TURN__)

  def get_history(self):
      """
      Returns the entire chat history as a single string.
      """
      return "".join([*self.history])

  def get_full_prompt(self):
    """
    Builds the prompt for the language model, including history and system description.
    """
    prompt = self.get_history() + self.__START_TURN_MODEL__
    if len(self.system)>0:
      prompt = self.system + "\n" + prompt
    return prompt

  def send_message(self, message):
    """
    Handles sending a user message and getting a model response.

    Args:
        message: The user's message.

    Returns:
        The model's response.
    """
    self.add_to_history_as_user(message)
    prompt = self.get_full_prompt()
    response = self.model.generate(prompt, max_length=1024)
    result = response.replace(prompt, "")  # Extract only the new response
    self.add_to_history_as_model(result)
    return result

与模型聊天

开始与模型聊天。

chat = ChatState(gemma_lm)
message = "Tell me, in a few words,  how to compute all prime numbers up to 1000?"
display_chat(message, chat.send_message(message))

🙋‍♂️

简单几句话,如何计算 1,000 以内的所有质数?
🤖?

伊拉托色尼珐三是一种广泛用于计算达到给定极限的所有质数的方法。它涉及以迭代方式标记每个质数的倍数。

继续对话。

message = "Now in Python! No numpy, please!"
display_chat(message, chat.send_message(message))

🙋‍♂️

现已支持 Python!别犹豫了!
🤖?

def prime(n):
    if n <= 1:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True

message = "Thank you, it works! Can you explain the code in plain English?"
display_chat(message, chat.send_message(message))

🙋‍♂️

谢谢,可以!您能用通俗易懂的英语解释一下代码吗?
🤖?

提供的代码定义了一个函数 prime,用于检查给定数字是否为质数。

运作方式

  • 该函数接受一个参数 n,即要检查的数字。

  • 首先,它会检查 n 是否小于或等于 1。如果正确,则该数字不是质数,因此函数会返回 False

  • 然后,它会进入一个循环,该循环对从 2 到 n 的平方根的数字进行迭代。

  • 对于每个数字 i,它会检查 n 是否可被 i 整除(即n % i == 0).

  • 如果 n 可被 i 整除,则该函数会返回 False,因为如果 n 有除数,它不能是质数。

  • 如果循环完成,没有为 n 找到任何除数,则该函数会返回 True,表示 n 是质数。

用法示例

>>> prime(2)
True
>>> prime(3)
True
>>> prime(4)
False
>>> prime(5)
True

此代码的优势

  • 它是一种简单高效的查找质数的算法。
  • 它广泛应用于各种计算机科学和数学应用。
message = "Great! Now add those explanations as comments in the code."
display_chat(message, chat.send_message(message))

🙋‍♂️

太棒了!现在,将这些解释作为注释添加到代码中。
🤖?

def prime(n):
    """
    Checks whether a given number is prime or not.

    Args:
        n: The number to check.

    Returns:
        True if n is prime, False otherwise.
    """

    # Check if n is less than or equal to 1.
    if n <= 1:
        return False

    # Iterate through numbers from 2 to the square root of n.
    for i in range(2, int(n**0.5) + 1):
        # Check if n is divisible by i.
        if n % i == 0:
            return False

    # If the loop completes without finding any divisors for n, then n is prime.
    return True

通过运行生成的代码来测试生成的响应:

def is_prime(n):
  """
  Checks if a number is prime.

  Args:
    n: The number to check.

  Returns:
    True if n is prime, False otherwise.
  """

  # If n is less than or equal to 1, it is not prime.
  if n <= 1:
    return False

  # Iterate through all the numbers from 2 to the square root of n.
  for i in range(2, int(n**0.5) + 1):
    # If n is divisible by any of the numbers in the range from 2 to the square root of n, it is not prime.
    if n % i == 0:
      return False

  # If no divisors are found, n is prime.
  return True


# Initialize an empty list to store prime numbers.
primes = []

# Iterate through all the numbers from 2 to 1000.
for i in range(2, 1001):
  # If the number is prime, add it to the list.
  if is_prime(i):
    primes.append(i)

# Print the prime numbers.
print(primes)
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997]

使用 get_history 方法可查看 Chat 类如何保留所有上下文。

print(chat.get_history())
<start_of_turn>user
Tell me, in a few words,  how to compute all prime numbers up to 1000?<end_of_turn>
<start_of_turn>model
The Sieve of Eratosthenes is a widely used method to compute all prime numbers up to a given limit. It involves iteratively marking out multiples of each prime number.<end_of_turn>
<start_of_turn>user
Now in Python! No numpy, please!<end_of_turn>
<start_of_turn>model

```python
def prime(n):
    if n <= 1:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True
```<end_of_turn>
<start_of_turn>user
Thank you, it works! Can you explain the code in plain English?<end_of_turn>
<start_of_turn>model
The provided code defines a function `prime` that checks whether a given number is prime or not.

**How it works:**

- The function takes a single argument, `n`, which is the number to check.


- It first checks if `n` is less than or equal to 1. If it is, the number is not prime, so the function returns `False`.


- It then enters a loop that iterates through numbers from 2 to the square root of `n`.


- For each number `i`, it checks if `n` is divisible evenly by `i` (i.e., `n % i == 0`).


- If `n` is divisible by `i`, the function returns `False` because `n` cannot be prime if it has a divisor.


- If the loop completes without finding any divisors for `n`, the function returns `True`, indicating that `n` is a prime number.


**Example Usage:**

```python
>>> prime(2)
True
>>> prime(3)
True
>>> prime(4)
False
>>> prime(5)
True
```

**Benefits of this Code:**

- It is a simple and efficient algorithm for finding prime numbers.
- It is widely used in various computer science and mathematical applications.<end_of_turn>
<start_of_turn>user
Great! Now add those explanations as comments in the code.<end_of_turn>
<start_of_turn>model
```python
def prime(n):
    """
    Checks whether a given number is prime or not.

    Args:
        n: The number to check.

    Returns:
        True if n is prime, False otherwise.
    """

    # Check if n is less than or equal to 1.
    if n <= 1:
        return False

    # Iterate through numbers from 2 to the square root of n.
    for i in range(2, int(n**0.5) + 1):
        # Check if n is divisible by i.
        if n % i == 0:
            return False

    # If the loop completes without finding any divisors for n, then n is prime.
    return True
```<end_of_turn>

总结和补充阅读材料

在本教程中,您学习了如何在 JAX 上使用 Keras 与 Gemma 2B 指令调优模型聊天。

查看以下指南和教程,详细了解 Gemma: