Gemma와 함께 챗봇 빌드

ai.google.dev에서 보기 Google Colab에서 실행 Vertex AI에서 열기 GitHub에서 소스 보기

Gemma와 같은 대규모 언어 모델 (LLM)은 유익한 응답을 생성하는 데 뛰어난 성능을 보여주므로 가상 어시스턴트와 챗봇을 빌드하는 데 이상적입니다.

일반적으로 LLM은 스테이트리스(Stateless) 방식으로 작동합니다. 즉, 과거 대화를 저장할 고유한 메모리가 없다는 의미입니다. 각 프롬프트 또는 질문은 이전 상호작용을 무시하고 독립적으로 처리됩니다. 그러나 자연스러운 대화의 중요한 측면은 이전 상호작용의 맥락을 유지하는 능력입니다. 이러한 제한을 극복하고 LLM이 대화 컨텍스트를 유지할 수 있도록 하려면 LLM에 표시되는 각각의 새 프롬프트에 대화 기록 (또는 관련 부분)과 같은 관련 정보가 명시적으로 제공되어야 합니다.

이 튜토리얼에서는 Gemma의 안내 조정된 모델 변형을 사용하여 챗봇을 개발하는 방법을 보여줍니다.

설정

Gemma 설정

이 튜토리얼을 완료하려면 먼저 Gemma 설정에서 설정 안내를 완료해야 합니다. Gemma 설정 안내에서는 다음 작업을 수행하는 방법을 보여줍니다.

  • kaggle.com에서 Gemma에 액세스해 보세요.
  • Gemma 2B 모델을 실행하기에 충분한 리소스가 있는 Colab 런타임을 선택합니다.
  • Kaggle 사용자 이름 및 API 키를 생성하고 구성합니다.

Gemma 설정을 완료한 후 다음 섹션으로 이동하여 Colab 환경의 환경 변수를 설정합니다.

환경 변수 설정하기

KAGGLE_USERNAMEKAGGLE_KEY의 환경 변수를 설정합니다.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

종속 항목 설치

Keras와 KerasNLP를 설치합니다.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q tensorflow-cpu
pip install -q -U keras-nlp tensorflow-hub
pip install -q -U keras>=3
pip install -q -U tensorflow-text

백엔드 선택

Keras는 단순성과 사용 편의성을 위해 설계된 높은 수준의 다중 프레임워크 딥 러닝 API입니다. Keras 3에서는 TensorFlow, JAX, PyTorch 중에서 백엔드를 선택할 수 있습니다. 이 튜토리얼에서는 세 가지 모두 사용할 수 있습니다.

import os

# Select JAX as the backend
os.environ["KERAS_BACKEND"] = "jax"

# Pre-allocate 100% of TPU memory to minimize memory fragmentation
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

패키지 가져오기

Keras와 KerasNLP를 가져옵니다.

import keras
import keras_nlp

# for reproducibility
keras.utils.set_random_seed(42)

모델 인스턴스화

KerasNLP는 널리 사용되는 여러 모델 아키텍처의 구현을 제공합니다. 이 튜토리얼에서는 인과 언어 모델링을 위한 엔드 투 엔드 Gemma 모델인 GemmaCausalLM를 사용하여 모델을 인스턴스화합니다. 인과 언어 모델은 이전 토큰을 기반으로 다음 토큰을 예측합니다.

from_preset 메서드를 사용하여 모델을 인스턴스화합니다.

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'task.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'preprocessor.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_1.1_instruct_2b_en/3' to your Colab notebook...

GemmaCausalLM.from_preset() 함수는 사전 설정된 아키텍처 및 가중치에서 모델을 인스턴스화합니다. 위 코드에서 "gemma_1.1_instruct_2b_en" 문자열은 매개변수가 20억 개 있는 Gemma 2B 모델의 사전 설정을 지정합니다. 7B, 9B, 27B 매개변수가 포함된 Gemma 모델도 사용할 수 있습니다. Gemma 모델의 코드 문자열은 Kaggle모델 변형 목록에서 찾을 수 있습니다.

summary 메서드를 사용하여 모델에 관한 자세한 정보를 가져옵니다.

gemma_lm.summary()

요약에서 볼 수 있듯이 모델에는 25억 개의 학습 가능한 매개변수가 있습니다.

서식 지정 도우미 함수 정의

from IPython.display import Markdown
import textwrap

def display_chat(prompt, text):
  formatted_prompt = "<font size='+1' color='brown'>🙋‍♂️<blockquote>" + prompt + "</blockquote></font>"
  text = text.replace('•', '  *')
  text = textwrap.indent(text, '> ', predicate=lambda _: True)
  formatted_text = "<font size='+1' color='teal'>🤖\n\n" + text + "\n</font>"
  return Markdown(formatted_prompt+formatted_text)

def to_markdown(text):
  text = text.replace('•', '  *')
  return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))

챗봇 빌드

Gemma 명령 조정 모델 gemma_1.1_instruct_2b_en는 다음 턴 토큰을 이해하도록 미세 조정됩니다.

<start_of_turn>user\n  ... <end_of_turn>\n
<start_of_turn>model\n ... <end_of_turn>\n

이 튜토리얼에서는 이러한 토큰을 사용하여 챗봇을 빌드합니다. Gemma 제어 토큰에 대한 자세한 내용은 형식 및 시스템 안내를 참조하세요.

대화 상태를 관리하는 채팅 도우미 만들기

class ChatState():
  """
  Manages the conversation history for a turn-based chatbot
  Follows the turn-based conversation guidelines for the Gemma family of models
  documented at https://ai.google.dev/gemma/docs/formatting
  """

  __START_TURN_USER__ = "<start_of_turn>user\n"
  __START_TURN_MODEL__ = "<start_of_turn>model\n"
  __END_TURN__ = "<end_of_turn>\n"

  def __init__(self, model, system=""):
    """
    Initializes the chat state.

    Args:
        model: The language model to use for generating responses.
        system: (Optional) System instructions or bot description.
    """
    self.model = model
    self.system = system
    self.history = []

  def add_to_history_as_user(self, message):
      """
      Adds a user message to the history with start/end turn markers.
      """
      self.history.append(self.__START_TURN_USER__ + message + self.__END_TURN__)

  def add_to_history_as_model(self, message):
      """
      Adds a model response to the history with start/end turn markers.
      """
      self.history.append(self.__START_TURN_MODEL__ + message + self.__END_TURN__)

  def get_history(self):
      """
      Returns the entire chat history as a single string.
      """
      return "".join([*self.history])

  def get_full_prompt(self):
    """
    Builds the prompt for the language model, including history and system description.
    """
    prompt = self.get_history() + self.__START_TURN_MODEL__
    if len(self.system)>0:
      prompt = self.system + "\n" + prompt
    return prompt

  def send_message(self, message):
    """
    Handles sending a user message and getting a model response.

    Args:
        message: The user's message.

    Returns:
        The model's response.
    """
    self.add_to_history_as_user(message)
    prompt = self.get_full_prompt()
    response = self.model.generate(prompt, max_length=1024)
    result = response.replace(prompt, "")  # Extract only the new response
    self.add_to_history_as_model(result)
    return result

모델과 채팅

모델과 채팅을 시작합니다.

chat = ChatState(gemma_lm)
message = "Tell me, in a few words,  how to compute all prime numbers up to 1000?"
display_chat(message, chat.send_message(message))

🙋‍♂️

1,000까지 모든 소수를 계산하는 방법을 간단하게 말해 보세요.

에라토스테네스 체는 주어진 극한까지 모든 소수를 계산하는 데 널리 사용되는 방법입니다. 여기에는 각 소수의 배수를 반복적으로 마크업하는 작업이 포함됩니다.

대화를 계속하세요.

message = "Now in Python! No numpy, please!"
display_chat(message, chat.send_message(message))

🙋‍♂️

이제 Python에서 사용할 수 있습니다. Numpy는 없습니다.

def prime(n):
    if n <= 1:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True

message = "Thank you, it works! Can you explain the code in plain English?"
display_chat(message, chat.send_message(message))

🙋‍♂️

감사합니다. 쉬운 영어로 코드를 설명해 주실 수 있나요?
LB

제공된 코드는 주어진 숫자가 소수인지 확인하는 prime 함수를 정의합니다.

작동 방식:

  • 이 함수는 확인할 숫자인 단일 인수 n를 사용합니다.

  • 먼저 n가 1 이하인지 확인합니다. 그렇다면 숫자가 소수가 아니므로 이 함수는 False을 반환합니다.

  • 그런 다음 2부터 n의 제곱근까지 숫자를 반복하는 루프를 입력합니다.

  • 각 숫자 i에 대해 ni (즉, n % i == 0).

  • ni로 나눌 수 있는 경우 함수는 False을 반환합니다. n에 제수가 있으면 소수가 될 수 없기 때문입니다.

  • n의 제수를 찾지 않고 루프가 완료되면 함수는 True을 반환하여 n가 소수임을 나타냅니다.

사용 예:

>>> prime(2)
True
>>> prime(3)
True
>>> prime(4)
False
>>> prime(5)
True

본 강령의 이점:

  • 소수를 찾기 위한 간단하고 효율적인 알고리즘입니다.
  • 다양한 컴퓨터 공학 및 수학 애플리케이션에서 널리 사용됩니다.
message = "Great! Now add those explanations as comments in the code."
display_chat(message, chat.send_message(message))

🙋‍♂️

좋습니다. 이제 이러한 설명을 코드에 주석으로 추가하세요.

def prime(n):
    """
    Checks whether a given number is prime or not.

    Args:
        n: The number to check.

    Returns:
        True if n is prime, False otherwise.
    """

    # Check if n is less than or equal to 1.
    if n <= 1:
        return False

    # Iterate through numbers from 2 to the square root of n.
    for i in range(2, int(n**0.5) + 1):
        # Check if n is divisible by i.
        if n % i == 0:
            return False

    # If the loop completes without finding any divisors for n, then n is prime.
    return True

생성된 코드를 실행하여 생성된 응답을 테스트합니다.

def is_prime(n):
  """
  Checks if a number is prime.

  Args:
    n: The number to check.

  Returns:
    True if n is prime, False otherwise.
  """

  # If n is less than or equal to 1, it is not prime.
  if n <= 1:
    return False

  # Iterate through all the numbers from 2 to the square root of n.
  for i in range(2, int(n**0.5) + 1):
    # If n is divisible by any of the numbers in the range from 2 to the square root of n, it is not prime.
    if n % i == 0:
      return False

  # If no divisors are found, n is prime.
  return True


# Initialize an empty list to store prime numbers.
primes = []

# Iterate through all the numbers from 2 to 1000.
for i in range(2, 1001):
  # If the number is prime, add it to the list.
  if is_prime(i):
    primes.append(i)

# Print the prime numbers.
print(primes)
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997]

get_history 메서드를 사용하여 Chat 클래스가 모든 컨텍스트를 어떻게 유지했는지 확인합니다.

print(chat.get_history())
<start_of_turn>user
Tell me, in a few words,  how to compute all prime numbers up to 1000?<end_of_turn>
<start_of_turn>model
The Sieve of Eratosthenes is a widely used method to compute all prime numbers up to a given limit. It involves iteratively marking out multiples of each prime number.<end_of_turn>
<start_of_turn>user
Now in Python! No numpy, please!<end_of_turn>
<start_of_turn>model

```python
def prime(n):
    if n <= 1:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True
```<end_of_turn>
<start_of_turn>user
Thank you, it works! Can you explain the code in plain English?<end_of_turn>
<start_of_turn>model
The provided code defines a function `prime` that checks whether a given number is prime or not.

**How it works:**

- The function takes a single argument, `n`, which is the number to check.


- It first checks if `n` is less than or equal to 1. If it is, the number is not prime, so the function returns `False`.


- It then enters a loop that iterates through numbers from 2 to the square root of `n`.


- For each number `i`, it checks if `n` is divisible evenly by `i` (i.e., `n % i == 0`).


- If `n` is divisible by `i`, the function returns `False` because `n` cannot be prime if it has a divisor.


- If the loop completes without finding any divisors for `n`, the function returns `True`, indicating that `n` is a prime number.


**Example Usage:**

```python
>>> prime(2)
True
>>> prime(3)
True
>>> prime(4)
False
>>> prime(5)
True
```

**Benefits of this Code:**

- It is a simple and efficient algorithm for finding prime numbers.
- It is widely used in various computer science and mathematical applications.<end_of_turn>
<start_of_turn>user
Great! Now add those explanations as comments in the code.<end_of_turn>
<start_of_turn>model
```python
def prime(n):
    """
    Checks whether a given number is prime or not.

    Args:
        n: The number to check.

    Returns:
        True if n is prime, False otherwise.
    """

    # Check if n is less than or equal to 1.
    if n <= 1:
        return False

    # Iterate through numbers from 2 to the square root of n.
    for i in range(2, int(n**0.5) + 1):
        # Check if n is divisible by i.
        if n % i == 0:
            return False

    # If the loop completes without finding any divisors for n, then n is prime.
    return True
```<end_of_turn>

요약 및 추가 자료

이 튜토리얼에서는 JAX에서 Keras를 사용하여 Gemma 2B Instruction 조정 모델과 채팅하는 방법을 배웠습니다.

Gemma에 대해 자세히 알아보려면 다음 가이드와 튜토리얼을 확인하세요.