Xây dựng bot trò chuyện bằng Gemma

Xem trên ai.google.dev Chạy trong Google Colab Mở trong Vertex AI Xem nguồn trên GitHub

Các Mô hình ngôn ngữ lớn (LLM) như Gemma xuất sắc trong việc tạo ra các câu trả lời chứa thông tin hữu ích, khiến các mô hình này trở nên lý tưởng để xây dựng trợ lý ảo và bot trò chuyện.

Thông thường, các LLM hoạt động ở dạng phi trạng thái, tức là chúng thiếu bộ nhớ vốn có để lưu trữ các cuộc trò chuyện trước đây. Mỗi câu lệnh hoặc câu hỏi được xử lý độc lập, bỏ qua những hoạt động tương tác trước đó. Tuy nhiên, một khía cạnh quan trọng của cuộc trò chuyện tự nhiên là khả năng giữ lại ngữ cảnh từ các tương tác trước đó. Để khắc phục hạn chế này và cho phép các LLM duy trì ngữ cảnh trò chuyện, các LLM phải được cung cấp rõ ràng thông tin liên quan như lịch sử trò chuyện (hoặc các phần thích hợp) vào mỗi câu lệnh mới được trình bày cho LLM.

Phần hướng dẫn này chỉ cho bạn cách phát triển một bot trò chuyện bằng cách sử dụng biến thể mô hình đã điều chỉnh theo hướng dẫn của Gemma.

Thiết lập

Thiết lập Gemma

Để hoàn tất hướng dẫn này, trước tiên, bạn cần phải hoàn thành hướng dẫn thiết lập trong phần thiết lập Gemma. Hướng dẫn thiết lập Gemma chỉ cho bạn cách thực hiện những việc sau:

  • Truy cập vào Gemma trên kaggle.com.
  • Chọn một môi trường thời gian chạy Colab có đủ tài nguyên để chạy mô hình Gemma 2B.
  • Tạo và định cấu hình tên người dùng Kaggle và khoá API.

Sau khi thiết lập xong Gemma, hãy chuyển sang phần tiếp theo. Tại đây, bạn sẽ thiết lập các biến môi trường cho môi trường Colab của mình.

Đặt các biến môi trường

Thiết lập các biến môi trường cho KAGGLE_USERNAMEKAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Cài đặt phần phụ thuộc

Cài đặt Keras và KerasNLP.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q tensorflow-cpu
pip install -q -U keras-nlp tensorflow-hub
pip install -q -U keras>=3
pip install -q -U tensorflow-text

Chọn một phần phụ trợ

Keras là một API học sâu cấp cao, đa khung, được thiết kế để mang lại trải nghiệm đơn giản và dễ sử dụng. Keras 3 cho phép bạn chọn phần phụ trợ: TensorFlow, JAX hoặc PyTorch. Cả ba đều phù hợp với hướng dẫn này.

import os

# Select JAX as the backend
os.environ["KERAS_BACKEND"] = "jax"

# Pre-allocate 100% of TPU memory to minimize memory fragmentation
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

Nhập gói

Nhập Keras và KerasNLP.

import keras
import keras_nlp

# for reproducibility
keras.utils.set_random_seed(42)

Tạo thực thể cho mô hình

KerasNLP cung cấp cách triển khai nhiều kiến trúc mô hình phổ biến. Trong hướng dẫn này, bạn sẽ tạo thực thể cho mô hình này bằng GemmaCausalLM, một mô hình Gemma toàn diện để lập mô hình ngôn ngữ quan hệ nhân quả. Mô hình ngôn ngữ nhân quả dự đoán mã thông báo tiếp theo dựa trên mã thông báo trước đó.

Tạo thực thể cho mô hình này bằng phương thức from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'task.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'preprocessor.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...

Hàm GemmaCausalLM.from_preset() tạo thực thể cho mô hình từ một cấu trúc và trọng số đặt trước. Trong mã trên, chuỗi "gemma_1.1_instruct_2b_en" chỉ định giá trị đặt trước cho mô hình Gemma 2B với 2 tỷ tham số. Bạn cũng có thể sử dụng các mô hình Gemma với thông số 7B, 9B và 27B. Bạn có thể tìm thấy các chuỗi mã cho mô hình Gemma trong trang thông tin Biến thể mô hình trên Kaggle.

Sử dụng phương thức summary để biết thêm thông tin về mô hình:

gemma_lm.summary()

Như bạn có thể thấy trong bản tóm tắt, mô hình này có 2,5 tỷ tham số có thể huấn luyện.

Xác định các hàm trợ giúp định dạng

from IPython.display import Markdown
import textwrap

def display_chat(prompt, text):
  formatted_prompt = "<font size='+1' color='brown'>🙋‍♂️<blockquote>" + prompt + "</blockquote></font>"
  text = text.replace('•', '  *')
  text = textwrap.indent(text, '> ', predicate=lambda _: True)
  formatted_text = "<font size='+1' color='teal'>🤖\n\n" + text + "\n</font>"
  return Markdown(formatted_prompt+formatted_text)

def to_markdown(text):
  text = text.replace('•', '  *')
  return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))

Xây dựng bot trò chuyện

Mô hình tinh chỉnh hướng dẫn của Gemma gemma_1.1_instruct_2b_en được tinh chỉnh để có thể hiểu các mã thông báo rẽ sau:

<start_of_turn>user\n  ... <end_of_turn>\n
<start_of_turn>model\n ... <end_of_turn>\n

Hướng dẫn này sử dụng các mã thông báo này để tạo chatbot. Hãy tham khảo phần Định dạng và hướng dẫn hệ thống để biết thêm thông tin về mã thông báo điều khiển Gemma.

Tạo một trình trợ giúp trò chuyện để quản lý trạng thái cuộc trò chuyện

class ChatState():
  """
  Manages the conversation history for a turn-based chatbot
  Follows the turn-based conversation guidelines for the Gemma family of models
  documented at https://ai.google.dev/gemma/docs/formatting
  """

  __START_TURN_USER__ = "<start_of_turn>user\n"
  __START_TURN_MODEL__ = "<start_of_turn>model\n"
  __END_TURN__ = "<end_of_turn>\n"

  def __init__(self, model, system=""):
    """
    Initializes the chat state.

    Args:
        model: The language model to use for generating responses.
        system: (Optional) System instructions or bot description.
    """
    self.model = model
    self.system = system
    self.history = []

  def add_to_history_as_user(self, message):
      """
      Adds a user message to the history with start/end turn markers.
      """
      self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)

  def add_to_history_as_model(self, message):
      """
      Adds a model response to the history with start/end turn markers.
      """
      self.history.append(self.__START_TURN_MODEL__ + message + self.__END_TURN__)

  def get_history(self):
      """
      Returns the entire chat history as a single string.
      """
      return "".join([*self.history])

  def get_full_prompt(self):
    """
    Builds the prompt for the language model, including history and system description.
    """
    prompt = self.get_history() + self.__START_TURN_MODEL__
    if len(self.system)>0:
      prompt = self.system + "\n" + prompt
    return prompt

  def send_message(self, message):
    """
    Handles sending a user message and getting a model response.

    Args:
        message: The user's message.

    Returns:
        The model's response.
    """
    self.add_to_history_as_user(message)
    prompt = self.get_full_prompt()
    response = self.model.generate(prompt, max_length=1024)
    result = response.replace(prompt, "")  # Extract only the new response
    self.add_to_history_as_model(result)
    return result

Trò chuyện với người mẫu

Bắt đầu trò chuyện với mô hình đó.

chat = ChatState(gemma_lm)
message = "Tell me, in a few words,  how to compute all prime numbers up to 1000?"
display_chat(message, chat.send_message(message))

🙋‍♂️

Nói một cách ngắn gọn, làm thế nào để tính tất cả các số nguyên tố có độ lớn đến 1000?
🤖

Sàng Eratosthenes là phương pháp được sử dụng rộng rãi để tính tất cả các số nguyên tố lên đến giới hạn đã cho. Phương pháp này liên tục đánh dấu bội số của từng số nguyên tố.

Tiếp tục cuộc trò chuyện.

message = "Now in Python! No numpy, please!"
display_chat(message, chat.send_message(message))

🙋‍♂️

Nay đã có trên Python! Không sao cả!
🤖

def prime(n):
    if n <= 1:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True

message = "Thank you, it works! Can you explain the code in plain English?"
display_chat(message, chat.send_message(message))

🙋‍♂️

Cảm ơn bạn, ứng dụng đã hoạt động! Bạn có thể giải thích đoạn mã này bằng tiếng Anh thuần tuý không?
🤖

Mã đã cung cấp xác định hàm prime kiểm tra xem một số đã cho có phải là số nguyên tố hay không.

Cách hoạt động:

  • Hàm này nhận một đối số duy nhất, n, là số cần kiểm tra.

  • Trước tiên, hàm này kiểm tra xem n có nhỏ hơn hay bằng 1. Nếu đúng, số này không phải là số nguyên tố, vì vậy hàm trả về False.

  • Sau đó, hàm này sẽ lặp lại quy trình lặp lại theo các số từ 2 đến căn bậc hai của n.

  • Đối với mỗi số i, hệ thống sẽ kiểm tra xem n có chia hết cho i hay không (tức là n % i == 0).

  • Nếu n chia hết cho i thì hàm sẽ trả về Falsen không thể là số nguyên tố nếu có số chia.

  • Nếu vòng lặp hoàn tất mà không tìm thấy ước số nào của n, hàm sẽ trả về True, cho biết n là số nguyên tố.

Ví dụ về cách sử dụng:

>>> prime(2)
True
>>> prime(3)
True
>>> prime(4)
False
>>> prime(5)
True

Lợi ích của Quy tắc này:

  • Đây là một thuật toán đơn giản và hiệu quả để tìm số nguyên tố.
  • Công cụ này được sử dụng rộng rãi trong nhiều ứng dụng toán học và khoa học máy tính.
message = "Great! Now add those explanations as comments in the code."
display_chat(message, chat.send_message(message))

🙋‍♂️

Vậy thì tuyệt quá! Bây giờ, hãy thêm những nội dung giải thích đó dưới dạng nhận xét vào mã.
🤖

def prime(n):
    """
    Checks whether a given number is prime or not.

    Args:
        n: The number to check.

    Returns:
        True if n is prime, False otherwise.
    """

    # Check if n is less than or equal to 1.
    if n <= 1:
        return False

    # Iterate through numbers from 2 to the square root of n.
    for i in range(2, int(n**0.5) + 1):
        # Check if n is divisible by i.
        if n % i == 0:
            return False

    # If the loop completes without finding any divisors for n, then n is prime.
    return True

Kiểm thử phản hồi được tạo bằng cách chạy mã đã tạo:

def is_prime(n):
  """
  Checks if a number is prime.

  Args:
    n: The number to check.

  Returns:
    True if n is prime, False otherwise.
  """

  # If n is less than or equal to 1, it is not prime.
  if n <= 1:
    return False

  # Iterate through all the numbers from 2 to the square root of n.
  for i in range(2, int(n**0.5) + 1):
    # If n is divisible by any of the numbers in the range from 2 to the square root of n, it is not prime.
    if n % i == 0:
      return False

  # If no divisors are found, n is prime.
  return True


# Initialize an empty list to store prime numbers.
primes = []

# Iterate through all the numbers from 2 to 1000.
for i in range(2, 1001):
  # If the number is prime, add it to the list.
  if is_prime(i):
    primes.append(i)

# Print the prime numbers.
print(primes)
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997]

Sử dụng phương thức get_history để xem cách lớp Chat giữ lại tất cả ngữ cảnh.

print(chat.get_history())
<start_of_turn>user
Tell me, in a few words,  how to compute all prime numbers up to 1000?<end_of_turn>
<start_of_turn>model
The Sieve of Eratosthenes is a widely used method to compute all prime numbers up to a given limit. It involves iteratively marking out multiples of each prime number.<end_of_turn>
<start_of_turn>user
Now in Python! No numpy, please!<end_of_turn>
<start_of_turn>model

```python
def prime(n):
    if n <= 1:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True
```<end_of_turn>
<start_of_turn>user
Thank you, it works! Can you explain the code in plain English?<end_of_turn>
<start_of_turn>model
The provided code defines a function `prime` that checks whether a given number is prime or not.

**How it works:**

- The function takes a single argument, `n`, which is the number to check.


- It first checks if `n` is less than or equal to 1. If it is, the number is not prime, so the function returns `False`.


- It then enters a loop that iterates through numbers from 2 to the square root of `n`.


- For each number `i`, it checks if `n` is divisible evenly by `i` (i.e., `n % i == 0`).


- If `n` is divisible by `i`, the function returns `False` because `n` cannot be prime if it has a divisor.


- If the loop completes without finding any divisors for `n`, the function returns `True`, indicating that `n` is a prime number.


**Example Usage:**

```python
>>> prime(2)
True
>>> prime(3)
True
>>> prime(4)
False
>>> prime(5)
True
```

**Benefits of this Code:**

- It is a simple and efficient algorithm for finding prime numbers.
- It is widely used in various computer science and mathematical applications.<end_of_turn>
<start_of_turn>user
Great! Now add those explanations as comments in the code.<end_of_turn>
<start_of_turn>model
```python
def prime(n):
    """
    Checks whether a given number is prime or not.

    Args:
        n: The number to check.

    Returns:
        True if n is prime, False otherwise.
    """

    # Check if n is less than or equal to 1.
    if n <= 1:
        return False

    # Iterate through numbers from 2 to the square root of n.
    for i in range(2, int(n**0.5) + 1):
        # Check if n is divisible by i.
        if n % i == 0:
            return False

    # If the loop completes without finding any divisors for n, then n is prime.
    return True
```<end_of_turn>

Tóm tắt và đọc thêm

Trong hướng dẫn này, bạn đã tìm hiểu cách trò chuyện với mô hình được điều chỉnh theo Hướng dẫn của Gemma 2B bằng cách sử dụng Keras trên JAX.

Hãy xem các hướng dẫn sau để tìm hiểu thêm về Gemma: